## 1. Import Modules and Data

According to the original paper, <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>, BERT is pretrained on BookCorpus and Wikipedia. More information about both datasets can be found on their Hugging Face pages: [bookcorpus](https://huggingface.co/datasets/bookcorpus/bookcorpus) and [wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).


BERT is pretrained on two tasks, Masked Language Model (Masked LM) and Next Sentence Prediction (NSP) (original paper, Section 3.1):

1. Masked LM: The training data generator randomly selects 15% of token positions for prediction. If the i-th token is selected, it is replaced with: (1) the [MASK] token 80% of the time, (2) a random token 10% of the time, or (3) the original i-th token 10% of the time. The model then predicts the original token using cross-entropy loss.

2. Next Sentence Prediction: This is a binary classification task. When selecting sentences A and B for each pre-training example, B is the actual subsequent sentence following A 50% of the time, and the other 50% of the time, B is a randomly chosen sentence from the corpus.

In [None]:
import itertools
import os

from data import load_data
import config

# Build one dataloader per corpus. Only the tokenizer returned by the
# bookcorpus call is kept; the one returned by the wikipedia call is discarded.
# NOTE(review): loading_ratio is presumably the fraction of each dataset to
# load -- confirm against data.load_data.
tokenizer, bookcorpus_dl = load_data("bookcorpus", loading_ratio=0.1)
_, wikipedia_dl = load_data("wikipedia", loading_ratio=1 / 41)

# Number of batches per corpus; their sum is the number of steps in one epoch
# when both dataloaders are iterated back to back.
bookcorpus_steps = len(bookcorpus_dl)
wikipedia_steps = len(wikipedia_dl)
dataloader_size = bookcorpus_steps + wikipedia_steps

print("bookcorpus_dl size:", bookcorpus_steps)
print("wikipedia_dl size:", wikipedia_steps)
print("Total Dataloader Size:", dataloader_size)

## 2. Build Model
The key structural difference between BERT and GPT is that BERT does not apply a causal mask in its attention layers. This lets BERT attend in both directions, so every token can directly capture dependencies on the whole sequence. Consequently, BERT's pre-training tasks are fundamentally different from GPT's next-token prediction task.

In [2]:
import torch
from modules.bert import BertForPreTraining

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the pre-training model from the project config and move it to
# the selected device. Vocabulary size and padding index come from the
# tokenizer returned by load_data so that model and data agree.
bert = BertForPreTraining(
    vocab_size=tokenizer.vocab_size,
    # presumably the two segment ids (sentence A / sentence B) -- confirm
    type_vocab_size=2,
    hidden_size=config.hidden_size,
    max_len=config.max_len,
    num_hidden_layers=config.num_layers,
    num_attention_heads=config.attention_heads,
    intermediate_size=config.intermediate_size,
    dropout=config.dropout,
    pad_token_idx=tokenizer.pad_token_id,
).to(device)

## 3. Pretrain Model


In [None]:
import torch
import torch.nn as nn
from tqdm.notebook import tqdm


def train(epoch, model, optimizer, scheduler):
    """Run one pre-training epoch over bookcorpus followed by wikipedia.

    Reads the module-level ``bookcorpus_dl``, ``wikipedia_dl``,
    ``dataloader_size``, ``tokenizer`` and ``device``. For every batch it
    performs a forward pass, back-propagates the total loss returned by the
    model, and steps both the optimizer and the scheduler.

    Args:
        epoch: Epoch number, used only for the progress bar and the log line.
        model: Model called with ``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``labels`` and ``next_sentence_label``; it
            must return ``(total_loss, prediction_scores,
            seq_relationship_score)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.

    Returns:
        The mean total loss per batch over the epoch (``0.0`` if the
        dataloaders yielded no batches).
    """
    model.train()
    loss_sum = 0.0
    num_batches = 0
    total_correct = 0
    total_element = 0

    # itertools.chain is a lazy iterator with no len(); the batch count is
    # therefore tracked explicitly and tqdm is given the expected total.
    dataloader = itertools.chain(bookcorpus_dl, wikipedia_dl)

    for batch in tqdm(
        dataloader, desc=f"Training Epoch {epoch}", total=dataloader_size
    ):
        input_ids = batch.input_ids.to(device)
        # The comparison already yields a bool tensor on input_ids' device.
        attention_mask = input_ids != tokenizer.pad_token_id
        token_type_ids = batch.token_type_ids.to(device)
        labels = batch.labels.to(device)
        is_next = batch.is_next.to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
            next_sentence_label=is_next,
        )
        total_loss, prediction_scores, seq_relationship_score = outputs

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        # next sentence prediction accuracy
        correct = (
            seq_relationship_score.argmax(dim=-1).eq(is_next.squeeze(-1)).sum().item()
        )

        loss_sum += total_loss.item()
        num_batches += 1
        total_correct += correct
        total_element += is_next.nelement()

    # Guard the divisions so an empty epoch reports zeros instead of raising.
    avg_loss = loss_sum / max(num_batches, 1)
    nsp_acc = total_correct * 100.0 / max(total_element, 1)

    print(
        "EP%d_train, avg_loss=" % (epoch),
        avg_loss,
        "NSP_acc=",
        nsp_acc,
    )

    return avg_loss

### 3.2 Train loop
A simple test run of the pretraining process.

In [None]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

print("Training Start")


def training_loop(restore_epoch=-1):
    """Pre-train ``bert`` for ``config.PretrainConfig.n_epoch`` epochs.

    Builds an AdamW optimizer and a linear warmup/decay schedule, optionally
    resumes from a saved checkpoint, then runs ``train`` once per epoch and
    writes a checkpoint after each epoch.

    Args:
        restore_epoch: Epoch number of the checkpoint to resume from
            (``bert_{restore_epoch}.pth`` in ``config.checkpoint_dir``).
            ``-1`` (the default) starts from scratch. If the requested
            checkpoint file does not exist, training also starts from scratch
            and a notice is printed.
    """
    optimizer = AdamW(
        bert.parameters(),
        lr=config.PretrainConfig.lr,
        weight_decay=config.PretrainConfig.weight_decay,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.PretrainConfig.warmup_steps,
        num_training_steps=dataloader_size * config.PretrainConfig.n_epoch,
    )

    # Checkpoints are saved below as bert_{epoch}.pth, so the restore path
    # must use the same prefix to ever find one.
    restore_ckpt_path = config.checkpoint_dir / f"bert_{restore_epoch}.pth"
    if restore_epoch != -1 and os.path.exists(restore_ckpt_path):
        # map_location lets a checkpoint saved on another device be loaded here.
        ckpt = torch.load(restore_ckpt_path, map_location=device)
        assert ckpt["epoch"] == restore_epoch
        bert.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
    else:
        if restore_epoch != -1:
            # Make the fallback visible instead of silently restarting.
            print(
                f"Checkpoint {restore_ckpt_path} not found; training from scratch."
            )
        restore_epoch = 0

    # torch.save does not create missing parent directories.
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    for epoch in range(restore_epoch, config.PretrainConfig.n_epoch):
        avg_train_loss = train(epoch + 1, bert, optimizer, scheduler)
        print(
            f"Epoch {epoch + 1}/{config.PretrainConfig.n_epoch}, Training Loss: {avg_train_loss: .4f}"
        )

        # Save everything needed to resume after this epoch.
        checkpoint_path = config.checkpoint_dir / f"bert_{epoch + 1}.pth"
        torch.save(
            {
                "epoch": epoch + 1,
                "model": bert.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            checkpoint_path,
        )


training_loop()