## 1. Import Modules and Data

According to the original paper <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>, BERT is pretrained on BookCorpus and English Wikipedia. More information about both datasets is available on their Hugging Face pages: [bookcorpus](https://huggingface.co/datasets/bookcorpus/bookcorpus) and [wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).


BERT is pretrained on two tasks, Masked Language Modeling (Masked LM) and Next Sentence Prediction (original paper, Section 3.1):

1. Masked LM: The training data generator randomly selects 15% of token positions for prediction. If the i-th token is selected, it is replaced with: (1) the [MASK] token 80% of the time, (2) a random token 10% of the time, or (3) the original i-th token 10% of the time. The model then predicts the original token using cross-entropy loss.

2. Next Sentence Prediction: This is a binary classification task. When selecting sentences A and B for each pre-training example, B is the actual subsequent sentence following A 50% of the time, and the other 50% of the time, B is a randomly chosen sentence from the corpus.
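
The two sampling rules above can be sketched in a few lines of plain Python. This is a minimal illustration, not the `load_data` implementation used in this notebook: `mask_tokens` and `make_nsp_pair` are hypothetical helper names, each position is selected independently with probability 0.15 (a simplification of choosing 15% of positions), and special tokens such as `[CLS]` and `[SEP]` are not excluded.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, vocab, mask_prob=0.15, rng=random):
    """Return (masked_tokens, labels); labels are None at unselected positions."""
    masked, labels = list(tokens), [None] * len(tokens)
    for i, token in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        labels[i] = token  # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_TOKEN         # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return masked, labels


def make_nsp_pair(sentences, index, rng=random):
    """Return (sentence_a, sentence_b, is_next) for the sentence at `index`.

    `index` must not point at the last sentence, which has no successor.
    """
    sentence_a = sentences[index]
    if rng.random() < 0.5:
        return sentence_a, sentences[index + 1], True  # actual next sentence
    return sentence_a, rng.choice(sentences), False    # random sentence from the corpus


# Toy usage with a seeded generator so runs are repeatable
rng = random.Random(0)
tokens = "the quick brown fox jumps over the lazy dog".split()
masked, labels = mask_tokens(tokens, vocab=tokens, rng=rng)
pair = make_nsp_pair(["first sentence", "second sentence", "third sentence"], 0, rng=rng)
```

Positions whose label is `None` contribute nothing to the Masked LM loss; only the selected positions are scored with cross-entropy.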

In [None]:
import itertools

from data import load_data
import config

tokenizer, bookcorpus_dl = load_data("bookcorpus", loading_ratio=0.1)
_, wikipedia_dl = load_data("wikipedia", loading_ratio=0.01)

train_dl = itertools.chain(bookcorpus_dl, wikipedia_dl)


## 2. Build Model
The key structural difference between BERT and GPT is that BERT does not use a causal mask in its attention layers. Every position can therefore attend to both its left and right context, so the model captures global dependencies directly. With the full sequence visible, next-token prediction would be trivial, which is why BERT's pre-training tasks differ fundamentally from GPT's next-token prediction task.
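
To make the difference concrete, the sketch below builds the two attention patterns for a sequence of four tokens. It is an illustration in plain Python only; `attention_mask` is a hypothetical helper and not part of `modules.bert`. A 1 at row `i`, column `j` means query position `i` may attend to key position `j`.

```python
def attention_mask(seq_len, causal):
    """Return a seq_len x seq_len matrix; 1 means query i may attend to key j."""
    return [
        [1 if (not causal or j <= i) else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


bidirectional = attention_mask(4, causal=False)  # BERT-style: every entry is 1
causal = attention_mask(4, causal=True)          # GPT-style: lower triangular

for name, mask in [("bidirectional", bidirectional), ("causal", causal)]:
    print(name)
    for row in mask:
        print(" ".join(str(v) for v in row))
```

In practice BERT still applies a padding mask so that real tokens do not attend to padding positions; that mask depends on the batch contents, not on token order.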

In [None]:
import torch
from modules.bert import BertForPreTraining

# Fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert = BertForPreTraining(
    vocab_size=tokenizer.vocab_size,
    type_vocab_size=2,
    hidden_size=config.hidden_size,
    max_len=config.max_len,
    num_hidden_layers=config.num_layers,
    num_attention_heads=config.attention_heads,
    intermediate_size=config.intermediate_size,
    dropout=config.dropout,
    pad_token_id=tokenizer.pad_token_id,
).to(device)

## 3. Pretrain Model
Here I define a BERTTrainer class that bundles the pretraining setup: optimizer, learning rate schedule, training loop, and checkpointing.

In [5]:
import torch
import torch.nn as nn
import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from modules import BertForPreTraining

class BERTTrainer:
    """
    BERTTrainer pretrains a BERT model with two objectives:

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction

    See README.md for details and a simple example.

    """
    def __init__(
            self, 
            bert: BertForPreTraining,
            train_dataloader: DataLoader, 
            test_dataloader: DataLoader = None,
            lr: float = 1e-4, 
            betas=(0.9, 0.999), 
            weight_decay: float = 0.01, 
            warmup_steps=10000,
            total_steps=1000000,
            with_cuda: bool = True, 
            log_freq: int = 10):
        """
        :param bert: BERT model which you want to train
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param warmup_steps: number of steps over which the learning rate warms up linearly
        :param total_steps: total number of training steps for the linear schedule
        :param with_cuda: training with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        # Use CUDA only when it is available and requested through with_cuda
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # Keep a reference to the model for checkpointing in save()
        self.bert = bert
        # Move the model to the selected device
        self.model = bert.to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        # if with_cuda and torch.cuda.device_count() > 1:
        #     print("Using %d GPUS for BERT" % torch.cuda.device_count())
        #     self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optimizer = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Negative log likelihood loss that ignores label index 0.
        # Note: iteration() uses the loss returned by the model and never calls this criterion.
        self.criterion = nn.NLLLoss(ignore_index=0)

        self.log_freq = log_freq

        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        Loop over the data_loader for training or testing.
        In train mode, the backward pass, optimizer step and scheduler step are run.
        Checkpoints are not written here; call save() for that.

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        # Switch dropout on for training and off for evaluation
        self.model.train(train)

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            # get batch data
            data = {key: value.to(self.device) for key, value in data.items()}

            # Call the module directly and skip gradient tracking when evaluating
            with torch.set_grad_enabled(train):
                outputs = self.model(
                    input_ids=data["bert_input"],
                    attention_mask=data["bert_attention_mask"],
                    token_type_ids=data["segment_label"],
                    labels=data["bert_label"],
                    next_sentence_label=data["is_next"]
                )
            total_loss, prediction_scores, seq_relationship_score = outputs
            # backward and optimization only in train
            if train:
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()
                self.scheduler.step()

            # next sentence prediction accuracy
            correct = seq_relationship_score.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += total_loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": total_loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
              total_correct * 100.0 / total_element)

    def save(self, epoch, file_path):
        """
        Save the current BERT model to a path derived from file_path.

        :param epoch: current epoch number
        :param file_path: output path prefix; the file is written to file_path + "_ep%d.pth" % epoch
        :return: final_output_path
        """
        output_path = str(file_path) + "_ep%d.pth" % epoch
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
trainer = BERTTrainer(
    bert,
    train_dataloader=train_data_loader, 
    test_dataloader=test_data_loader,       
    lr=config.learning_rate,     
    betas=(config.adam_beta1, config.adam_beta2),
    weight_decay=config.adam_weight_decay,
    warmup_steps=config.warmup_steps,
    total_steps=config.total_steps,
    with_cuda=config.with_cuda, 
    log_freq=config.log_freq
)

Total Parameters: 110106428
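
The trainer pairs Adam with `get_linear_schedule_with_warmup`, which scales the base learning rate by a factor that rises linearly from 0 to 1 over `warmup_steps` and then falls linearly to 0 at `total_steps`. The sketch below reproduces that factor in plain Python to make the shape explicit; `linear_warmup_factor` is a hypothetical name used only for illustration, and the step counts are the defaults from the `BERTTrainer` signature rather than the values in `config`.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # linear ramp from 0 up to 1 during warmup
        return step / max(1, warmup_steps)
    # linear decay from 1 down to 0 between warmup_steps and total_steps
    remaining = total_steps - step
    return max(0.0, remaining / max(1, total_steps - warmup_steps))


base_lr = 1e-4
for step in (0, 5_000, 10_000, 505_000, 1_000_000):
    factor = linear_warmup_factor(step, warmup_steps=10_000, total_steps=1_000_000)
    print(step, base_lr * factor)
```

With only a handful of batches per epoch, as in the short run in this notebook, training never leaves the warmup phase under these defaults, so the effective learning rate stays far below `lr`.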


### 3.2 Train loop
A short test run of the pretraining process.

In [6]:
print("Training Start")
for epoch in range(config.epochs):
    trainer.train(epoch)
    if epoch % config.log_freq == 0:
        trainer.save(epoch, config.trained_path)
    if test_data_loader is not None:
        trainer.test(epoch)

Training Start

EP_train:0: 100%|| 11/11 [00:01<00:00,  5.89it/s]
{'epoch': 0, 'iter': 0, 'avg_loss': 11.511968612670898, 'avg_acc': 100.0, 'loss': 11.511968612670898}
{'epoch': 0, 'iter': 5, 'avg_loss': 11.693164825439453, 'avg_acc': 50.0, 'loss': 11.782041549682617}
{'epoch': 0, 'iter': 10, 'avg_loss': 11.638592720031738, 'avg_acc': 59.09090909090909, 'loss': 11.53420352935791}
EP0_train, avg_loss= 11.638592720031738 total_acc= 59.09090909090909
EP:0 Model Saved on: /home/bks/lzh/checkpoints/bert_self_trained_ep0.pth

EP_train:1: 100%|| 11/11 [00:00<00:00, 11.90it/s]
{'epoch': 1, 'iter': 0, 'avg_loss': 11.633398056030273, 'avg_acc': 50.0, 'loss': 11.633398056030273}
{'epoch': 1, 'iter': 5, 'avg_loss': 11.564441521962484, 'avg_acc': 66.66666666666666, 'loss': 11.521570205688477}
{'epoch': 1, 'iter': 10, 'avg_loss': 11.480820222334428, 'avg_acc': 54.54545454545454, 'loss': 11.199502944946289}
EP1_train, avg_loss= 11.480820222334428 total_acc= 54.54545454545455

EP_train:2: 100%|| 11/11 [00:00<00:00, 12.34it/s]
{'epoch': 2, 'iter': 0, 'avg_loss': 11.45975399017334, 'avg_acc': 0.0, 'loss': 11.45975399017334}
{'epoch': 2, 'iter': 5, 'avg_loss': 11.160041014353434, 'avg_acc': 50.0, 'loss': 11.001148223876953}
{'epoch': 2, 'iter': 10, 'avg_loss': 11.023141774264248, 'avg_acc': 54.54545454545454, 'loss': 10.763666152954102}
EP2_train, avg_loss= 11.023141774264248 total_acc= 54.54545454545455

EP_train:3: 100%|| 11/11 [00:00<00:00, 11.74it/s]
{'epoch': 3, 'iter': 0, 'avg_loss': 10.7632417678833, 'avg_acc': 50.0, 'loss': 10.7632417678833}
{'epoch': 3, 'iter': 5, 'avg_loss': 10.52497657140096, 'avg_acc': 58.333333333333336, 'loss': 10.387877464294434}
{'epoch': 3, 'iter': 10, 'avg_loss': 10.382498654452236, 'avg_acc': 45.45454545454545, 'loss': 10.135360717773438}
EP3_train, avg_loss= 10.382498654452236 total_acc= 45.45454545454545

EP_train:4: 100%|| 11/11 [00:00<00:00, 11.91it/s]
{'epoch': 4, 'iter': 0, 'avg_loss': 9.780165672302246, 'avg_acc': 100.0, 'loss': 9.780165672302246}
{'epoch': 4, 'iter': 5, 'avg_loss': 9.686190605163574, 'avg_acc': 75.0, 'loss': 9.539986610412598}
{'epoch': 4, 'iter': 10, 'avg_loss': 9.480205015702682, 'avg_acc': 59.09090909090909, 'loss': 9.070290565490723}
EP4_train, avg_loss= 9.480205015702682 total_acc= 59.09090909090909



