In [1]:
# import packages
import math
import os
import random as py_random  # module handle used for seeding; bare `random` below is the function
import re
import time
from random import randrange, shuffle, random

import datasets
import matplotlib.pyplot as plt
import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from tqdm import tqdm

In [None]:
SEED = 42
# make_bert_batch samples with Python's `random` module (randrange/shuffle/random),
# so it must be seeded alongside numpy and torch for reproducible batches.
py_random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning can pick non-deterministic kernels

# Load dataset

The **`wikimedia/wikipedia`** dataset is a large open-source text collection built from Wikipedia articles and distributed through the Hugging Face `datasets` library. The **`20231101.en`** configuration is based on the English Wikipedia dump of November 1, 2023. It contains millions of articles covering topics such as history, science, technology, culture, and geography. Each record holds the article's unique ID, its URL, its title, and the cleaned full text. The dataset is widely used for natural language processing tasks such as language modeling, search, summarization, and question answering. To keep preprocessing and training fast, this notebook loads only the first 1,500 articles (`split="train[:1500]"`).

In [None]:
# First 1,500 articles of the English Wikipedia dump (2023-11-01) from the Hugging Face Hub.
dataset = load_dataset(
    path="wikimedia/wikipedia",
    name="20231101.en",
    split="train[:1500]",
)

In [None]:
# One record per article; len(dataset) counts rows without materialising the "text" column.
print(f"{len(dataset)} lists")

In [None]:
# Rough sentence count: split on "." and ignore empty / whitespace-only pieces.
# Without that filter, an article ending with "." adds a phantom empty "sentence".
# Abbreviations and decimals still inflate the number, so treat it as an estimate.
approx_sentences = 0
for record in dataset:
    approx_sentences += sum(1 for piece in record["text"].split(".") if piece.strip())
print(f"{approx_sentences} sentences")

In [None]:
# Whitespace-delimited word total across all loaded articles (generator, no temp list).
print(f"{sum(len(article.split()) for article in dataset['text'])} words")

# Preprocessing data

In [None]:
def process_data(dataset, max_articles=100000, max_vocab=30000):
    """Sentence-split, clean and tokenize articles, then build a frequency-ranked vocabulary.

    Each article's text is split into sentences with spaCy's rule-based sentencizer.
    Every sentence is lower-cased, has the characters ``. , ! ? \\ -`` replaced by
    spaces, and is split on whitespace. Empty sentences are dropped.

    Args:
        dataset: iterable of records with a ``"text"`` field (e.g. a Hugging Face split).
        max_articles: only the first ``max_articles`` records are processed.
        max_vocab: upper bound on the vocabulary size, *including* the 5 special tokens.

    Returns:
        tokenized_ids: list of id lists, one per non-empty sentence
            (out-of-vocabulary words map to the ``"[UNK]"`` id).
        word2key: dict token -> id; specials occupy ids 0-4, then words by descending count.
        key2word: inverse mapping id -> token.
        vocab_size: ``len(word2key)`` (always defined, even for an empty corpus).
    """
    # Blank English pipeline + sentencizer only (no parser, no NER): the
    # sentencizer applies simple rule-based sentence boundaries.
    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer")

    word_counts = {}  # token -> frequency over all sentences
    all_tokenized_sentences = []  # one token list per sentence

    for article_idx, record in enumerate(dataset):
        if article_idx >= max_articles:
            break
        text = record["text"]
        # spaCy rejects texts longer than nlp.max_length. The limit exists for
        # the memory-hungry parser/NER, which this pipeline does not use.
        if len(text) >= nlp.max_length:
            nlp.max_length = len(text) + 1
        doc = nlp(text)

        for sent in doc.sents:
            cleaned = re.sub(r"[.,!?\\-]", " ", sent.text.lower())
            tokens = cleaned.split()
            if not tokens:
                continue
            all_tokenized_sentences.append(tokens)
            for word in tokens:
                word_counts[word] = word_counts.get(word, 0) + 1

    word2key = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[MASK]": 3, "[UNK]": 4}
    n_special = len(word2key)
    # Keep the `max_vocab - 5` most frequent words. sorted() is stable, so ties
    # keep first-occurrence order. Clamp at 0 so a tiny max_vocab cannot turn
    # into a negative slice (which would keep almost everything).
    ranked = sorted(word_counts.items(), key=lambda item: item[1], reverse=True)
    for offset, (word, _count) in enumerate(ranked[: max(0, max_vocab - n_special)]):
        word2key[word] = n_special + offset
    vocab_size = len(word2key)  # <= max(max_vocab, 5)

    key2word = {k: w for w, k in word2key.items()}

    unk_id = word2key["[UNK]"]
    tokenized_ids = [
        [word2key.get(word, unk_id) for word in sentence]
        for sentence in all_tokenized_sentences
    ]

    return tokenized_ids, word2key, key2word, vocab_size

## Data loader

In [None]:
def make_bert_batch(
    tokenized_sentences, word2id, max_len=128, max_pred=20, batch_size=32
):
    """Sample a batch of BERT pre-training examples (masked LM + next-sentence prediction).

    Each example is ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``:

    * ``input_ids`` / ``segment_ids``: ``[CLS] A [SEP] B [SEP]`` right-padded with the
      ``"[PAD]"`` id (0 if absent) to ``max_len``. Segment 0 covers ``[CLS] A [SEP]``
      and segment 1 covers ``B [SEP]``. The longer side is trimmed if the pair does
      not fit.
    * ``masked_tokens`` / ``masked_pos``: original ids and positions of the masked
      tokens. About 15% of the sequence is masked (at least 1, at most ``max_pred``),
      chosen among tokens other than [CLS]/[SEP]/[PAD]. Both lists are padded to
      ``max_pred`` with label -100 / position 0, so ``CrossEntropyLoss(ignore_index=-100)``
      skips the padding. Each chosen position becomes ``"[MASK]"`` with probability
      0.8, a random non-special token with probability 0.1, and stays unchanged otherwise.
    * ``is_next``: 1 when B is ``tokenized_sentences[i + 1]`` for A = ``tokenized_sentences[i]``.
      Otherwise it is 0. Adjacent list entries may come from different articles.

    The batch holds ``batch_size // 2`` IsNext pairs and ``batch_size - batch_size // 2``
    NotNext pairs, so odd sizes and ``batch_size=1`` still return ``batch_size`` examples.
    IsNext pairs are drawn directly (``i`` uniform, ``j = i + 1``) rather than found by
    rejection. This gives the same pair distribution without needing about
    ``len(tokenized_sentences)`` draws per positive example.

    Args:
        tokenized_sentences: list of token-id lists (not modified).
        word2id: vocabulary containing at least "[CLS]", "[SEP]" and "[MASK]".
        max_len: padded sequence length.
        max_pred: number of MLM slots per example.
        batch_size: number of examples to return.

    Raises:
        ValueError: if ``tokenized_sentences`` is empty (and ``batch_size > 0``), or
            IsNext pairs are requested with fewer than two sentences.
    """
    n_sentences = len(tokenized_sentences)
    n_positive = batch_size // 2
    n_negative = batch_size - n_positive
    if batch_size > 0 and n_sentences == 0:
        raise ValueError("tokenized_sentences must not be empty")
    if n_positive > 0 and n_sentences < 2:
        raise ValueError("at least two sentences are required to build IsNext pairs")

    cls_id, sep_id, mask_id = word2id["[CLS]"], word2id["[SEP]"], word2id["[MASK]"]
    pad_id = word2id.get("[PAD]", 0)
    vocab_len = len(word2id)
    # Random replacements come from ids after the special tokens. In the vocabulary
    # built by process_data, [PAD]/[CLS]/[SEP]/[MASK]/[UNK] occupy ids 0-4.
    first_regular_id = 1 + max(
        word2id[name]
        for name in ("[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]")
        if name in word2id
    )

    def build_example(tokens_a, tokens_b, is_next):
        """Assemble one padded [input_ids, segment_ids, masked_tokens, masked_pos, is_next]."""
        tokens_a, tokens_b = list(tokens_a), list(tokens_b)  # never mutate the corpus
        # Trim the longer side until [CLS] A [SEP] B [SEP] fits in max_len.
        while len(tokens_a) + len(tokens_b) > max_len - 3:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

        candidates = [
            pos
            for pos, tok in enumerate(input_ids)
            if tok not in (cls_id, sep_id, pad_id)
        ]
        shuffle(candidates)
        # Cap by the number of maskable tokens as well, so the padded lists
        # below always end up with exactly max_pred entries.
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)), len(candidates))

        masked_tokens, masked_pos = [], []
        for pos in candidates[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            r = random()
            if r < 0.8:
                input_ids[pos] = mask_id  # 80%: [MASK]
            elif r < 0.9 and first_regular_id < vocab_len:
                # 10%: a random regular (non-special) token
                input_ids[pos] = int(np.random.randint(first_regular_id, vocab_len))
            # else 10%: keep the original token

        input_ids.extend([pad_id] * (max_len - len(input_ids)))
        segment_ids.extend([0] * (max_len - len(segment_ids)))
        n_fill = max_pred - len(masked_tokens)
        masked_tokens.extend([-100] * n_fill)  # ignored by CrossEntropyLoss(ignore_index=-100)
        masked_pos.extend([0] * n_fill)  # any index works; its label is ignored
        return [input_ids, segment_ids, masked_tokens, masked_pos, is_next]

    batch = []
    positive = negative = 0
    while positive < n_positive or negative < n_negative:
        want_positive = positive < n_positive and (
            negative >= n_negative or random() < 0.5
        )
        if want_positive:
            i = randrange(n_sentences - 1)
            j = i + 1
        else:
            # Uniform over non-consecutive pairs: redraw while j == i + 1.
            i, j = randrange(n_sentences), randrange(n_sentences)
            while j == i + 1:
                i, j = randrange(n_sentences), randrange(n_sentences)
        is_next = 1 if j == i + 1 else 0
        batch.append(
            build_example(tokenized_sentences[i], tokenized_sentences[j], is_next)
        )
        if is_next:
            positive += 1
        else:
            negative += 1

    return batch

# BERT model

In [None]:
class Embedding(nn.Module):
    """BERT input embedding: token + learned position + segment vectors, then LayerNorm.

    Args:
        vocab_size: number of token ids V.
        d_model: embedding width D.
        max_len: number of learned positions M (sequences longer than M cannot be embedded).
        n_segments: number of segment ids S.
    """

    def __init__(self, vocab_size, d_model=768, max_len=512, n_segments=2):
        super().__init__()
        # Creation order (tok -> pos -> seg -> norm) is kept so parameter
        # initialisation draws from the RNG in the same sequence.
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # (V, D)
        self.pos_embed = nn.Embedding(max_len, d_model)  # (M, D)
        self.seg_embed = nn.Embedding(n_segments, d_model)  # (S, D)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed token ids ``x`` and segment ids ``seg`` (both (B, L)) into (B, L, D)."""
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        positions = positions[None, :].expand_as(x)  # (B, L): 0..L-1 on every row
        summed = self.tok_embed(x) + self.pos_embed(positions)
        summed = summed + self.seg_embed(seg)
        return self.norm(summed)


def get_attn_pad_mask(seq_q, seq_k):
    """Boolean mask (B, L_q, L_k) that is True wherever the *key* token is padding (id 0).

    Args:
        seq_q: (B, L_q) query token ids (only its length is used).
        seq_k: (B, L_k) key token ids; padding is the id 0.
    """
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    key_is_pad = seq_k.eq(0)  # (B, L_k)
    # Same key mask for every query row.
    return key_is_pad[:, None, :].expand(batch_size, len_q, len_k)


class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V with masked key positions excluded."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        Args:
            Q: queries (..., L_q, d_k); (B, H, L_q, d_k) when called from MultiHeadAttention.
            K: keys (..., L_k, d_k).
            V: values (..., L_k, d_v).
            attn_mask: boolean tensor with the shape of the scores (..., L_q, L_k).
                True marks key positions that must receive zero attention.

        Returns:
            context: (..., L_q, d_v) attention-weighted values.
            attn: (..., L_q, L_k) attention weights (each row sums to 1).
        """
        d_k = Q.size(-1)
        # (..., L_q, d_k) @ (..., d_k, L_k) -> (..., L_q, L_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        # Use the most negative finite value of the score dtype. A fixed -1e9
        # cannot be represented in float16 and makes masked_fill_ raise.
        scores.masked_fill_(attn_mask, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)  # (..., L_q, L_k)
        context = torch.matmul(attn, V)  # (..., L_q, d_v)
        return context, attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by an output projection, residual add and LayerNorm.

    Args:
        d_model: feature width D of the inputs and of the output.
        n_heads: number of heads H.
        d_k: per-head query/key width.
        d_v: per-head value width.
    """

    def __init__(self, d_model=768, n_heads=12, d_k=64, d_v=64):
        super().__init__()
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Creation order (W_Q, W_K, W_V, linear, norm) is kept so parameter
        # initialisation consumes the RNG identically.
        self.W_Q = nn.Linear(d_model, d_k * n_heads)  # (B, L, D) -> (B, L, H*d_k)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)  # (B, L, D) -> (B, L, H*d_v)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Attend from Q to K/V.

        Args:
            Q, K, V: (B, L, D) inputs (all the same tensor for self-attention).
            attn_mask: (B, L_q, L_k) boolean mask, True where attention is blocked.

        Returns:
            output: (B, L_q, D) = LayerNorm(linear(heads) + Q).
            attn: (B, H, L_q, L_k) attention weights.
        """
        residual, batch_size = Q, Q.size(0)

        def to_heads(projection, x, head_dim):
            # (B, L, H*head_dim) -> (B, L, H, head_dim) -> (B, H, L, head_dim)
            return projection(x).view(batch_size, -1, self.n_heads, head_dim).transpose(1, 2)

        q_s = to_heads(self.W_Q, Q, self.d_k)
        k_s = to_heads(self.W_K, K, self.d_k)
        v_s = to_heads(self.W_V, V, self.d_v)

        # Share the same (B, L_q, L_k) mask across every head -> (B, H, L_q, L_k).
        head_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, head_mask)

        # (B, H, L, d_v) -> (B, L, H*d_v)
        merged = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_v
        )
        return self.norm(self.linear(merged) + residual), attn


class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network: Linear(D -> d_ff) -> GELU -> Linear(d_ff -> D).

    Applied independently at every sequence position. Note that it returns only the
    transformed features; no residual connection or normalisation is applied here.
    """

    def __init__(self, d_model=768, d_ff=3072):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = torch.nn.functional.gelu(self.fc1(x))  # (..., d_ff)
        return self.fc2(hidden)  # (..., D)


class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention, then the position-wise FFN.

    NOTE(review): the FFN output is returned directly, without a residual
    connection or LayerNorm around it (unlike the standard Transformer encoder).
    Changing that would alter the trained behaviour and, if a LayerNorm is added,
    the state_dict keys of existing checkpoints.
    """

    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads)  # self-attention sublayer
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)  # feed-forward sublayer

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Return (features after attention + FFN, attention weights)."""
        attended, attn_weights = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        return self.pos_ffn(attended), attn_weights


class BERT(nn.Module):
    """Transformer encoder with a masked-LM head and a next-sentence-prediction head.

    Args:
        vocab_size: number of token ids V.
        d_model: hidden width D.
        n_layers: number of EncoderLayer blocks.
        n_heads: attention heads per layer.
        max_len: size of the learned positional-embedding table.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, n_heads=12, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        # Sub-module creation order is kept so parameter initialisation matches.
        self.embedding = Embedding(vocab_size, d_model, max_len)  # (B, L, D)
        self.layers = nn.ModuleList(EncoderLayer(d_model, n_heads) for _ in range(n_layers))

        # NOTE(review): `fc` and `activ` are never used in forward (NSP reads the raw
        # [CLS] vector). They are kept because `fc` contributes state_dict keys, and
        # removing it would break strict loading of existing checkpoints.
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()

        # MLM transform: Linear -> GELU -> LayerNorm (applied in forward).
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

        # NSP head on the [CLS] position.
        self.nsp_classifier = nn.Linear(d_model, 2)

        # MLM decoder whose weight is shared with the token embedding matrix.
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.decoder.weight = self.embedding.tok_embed.weight
        self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids, segment_ids, masked_pos=None):
        """Run the encoder.

        Args:
            input_ids, segment_ids: (B, L) id tensors (padding id 0 is masked out).
            masked_pos: optional (B, N_mask) positions to predict.

        Returns:
            (mlm_logits (B, N_mask, V), nsp_logits (B, 2)) when ``masked_pos`` is given,
            otherwise (encoder output (B, L, D), nsp_logits (B, 2)).
        """
        hidden = self.embedding(input_ids, segment_ids)  # (B, L, D)
        pad_mask = get_attn_pad_mask(input_ids, input_ids)  # (B, L, L)
        for encoder_layer in self.layers:
            hidden, _ = encoder_layer(hidden, pad_mask)

        nsp_logits = self.nsp_classifier(hidden[:, 0, :])  # [CLS] -> (B, 2)

        if masked_pos is None:
            return hidden, nsp_logits

        gather_index = masked_pos.unsqueeze(-1).expand(-1, -1, self.d_model)  # (B, N_mask, D)
        h_masked = torch.gather(hidden, 1, gather_index)  # (B, N_mask, D)
        h_masked = self.norm(torch.nn.functional.gelu(self.linear(h_masked)))
        mlm_logits = self.decoder(h_masked) + self.decoder_bias  # (B, N_mask, V)
        return mlm_logits, nsp_logits

# Training

In [None]:
# Sentence-split, clean and index the corpus; the vocabulary is capped at
# max_vocab entries (special tokens included).
(tokenized_ids, word2key, key2word, vocab_size) = process_data(
    dataset, max_vocab=30000
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: each "epoch" below is ONE optimisation step on a freshly sampled batch
# (make_bert_batch draws random sentence pairs), not a pass over the corpus.
total_epoches = 100

# Architecture hyper-parameters (also written into the checkpoint when saving).
d_model = 512
n_layers = 8
n_heads = 8
max_len = 128
n_segments = 2


def train_bert():
    """Pre-train a BERT model from scratch with the MLM + NSP objectives.

    Reads the notebook globals ``tokenized_ids``, ``word2key``, ``vocab_size`` and
    ``device``, plus the hyper-parameters defined in this cell.

    Returns:
        best_model: the model holding the weights that gave the lowest single-batch
            loss. Batches are random, so this is a noisy selection criterion.
        avg_train_loss: mean total loss over all steps (NaN if no steps were run).
        total_time: wall-clock training time in seconds (time.perf_counter).
    """
    best_state = None
    best_loss = float("inf")
    total_loss = 0.0

    model = BERT(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_len=max_len,
    )
    model.to(device)

    # Padded MLM slots carry label -100 (see make_bert_batch); ignore them explicitly.
    criterion_mlm = nn.CrossEntropyLoss(ignore_index=-100)
    criterion_nsp = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    n_epochs = total_epoches
    batch_size = 32

    start_time = time.perf_counter()

    print("Starting training...")
    for epoch in range(n_epochs):
        model.train()

        # Fresh batch; max_len must match the model's positional-embedding table.
        batch_data = make_bert_batch(
            tokenized_ids, word2key, max_len=max_len, batch_size=batch_size
        )
        input_ids, segment_ids, masked_tokens, masked_pos, is_next = zip(*batch_data)

        input_ids = torch.LongTensor(input_ids).to(device)
        segment_ids = torch.LongTensor(segment_ids).to(device)
        masked_tokens = torch.LongTensor(masked_tokens).to(device)
        masked_pos = torch.LongTensor(masked_pos).to(device)
        is_next = torch.LongTensor(is_next).to(device)

        optimizer.zero_grad()
        mlm_logits, nsp_logits = model(input_ids, segment_ids, masked_pos)

        loss_mlm = criterion_mlm(
            mlm_logits.view(-1, vocab_size), masked_tokens.view(-1)
        )
        loss_nsp = criterion_nsp(nsp_logits, is_next)
        loss = loss_mlm + loss_nsp

        loss_value = loss.item()
        total_loss += loss_value
        if loss_value < best_loss:
            best_loss = loss_value
            # Snapshot a *copy* of the weights that produced this loss (before
            # the optimizer step). Storing `model` itself would only alias the
            # final model. Copies live on the CPU to avoid doubling GPU memory.
            best_state = {
                name: tensor.detach().to("cpu", copy=True)
                for name, tensor in model.state_dict().items()
            }

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss_value:.4f}")

    total_time = time.perf_counter() - start_time
    avg_train_loss = total_loss / n_epochs if n_epochs else float("nan")

    if best_state is not None:
        model.load_state_dict(best_state)
    print("BERT training completed!")
    return model, avg_train_loss, total_time

In [None]:
# Run pre-training; returns (best_model, avg_train_loss, total_time in seconds).
(best_model, avg_train_loss, total_time) = train_bert()

In [None]:
# total_time comes from time.perf_counter() differences, which are in seconds.
print(f"Avg. training loss : {avg_train_loss:.2f}")
print(f"Training time(s)   : {total_time:.2f}")

<h3>Train Loss</h3>
<img src="bert%20train%20loss.png" alt="BERT training loss curve" width="500" />


# Load model and Inference

In [None]:
# Checkpoint directory: ./app/saved_models relative to the current working directory.
path_ = os.path.join(os.path.join(os.getcwd(), "app"), "saved_models")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps tensors saved from a CUDA run onto `device`. Without it,
# loading a GPU-trained checkpoint on a CPU-only machine fails.
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint = torch.load(os.path.join(path_, "bert_trained.pth"), map_location=device)

word2key = checkpoint["word2id"]  # token -> index mapping
key2word = checkpoint["id2word"]  # index -> token mapping
vocab_size = checkpoint["vocab_size"]

# Recreate BERT with the EXACT same configuration used during training.
# The fallbacks are this notebook's training defaults, for checkpoints that lack the keys.
model = BERT(
    vocab_size=vocab_size,
    d_model=checkpoint.get("d_model", 512),
    n_layers=checkpoint.get("n_layers", 8),
    n_heads=checkpoint.get("n_heads", 8),
    max_len=checkpoint.get("max_len", 128),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()


def tokenize_sentence(sentence, word2key, max_len):
    """Convert a raw sentence into exactly ``max_len`` token ids.

    Applies the same normalisation used when building the training vocabulary in
    ``process_data``: lower-casing, then replacing the punctuation characters
    ``. , ! ?``, backslash and hyphen with spaces. Without this, capitalised or
    punctuated words such as "The" or "dog." would all map to ``"[UNK]"``.
    Each whitespace-separated token is looked up in ``word2key``, with unknown
    words mapped to the ``"[UNK]"`` id. The result is truncated, or right-padded
    with the ``"[PAD]"`` id, to ``max_len``.
    """
    cleaned = re.sub(r"[.,!?\\-]", " ", sentence.lower())
    unk_id = word2key.get("[UNK]")
    token_ids = [word2key.get(word, unk_id) for word in cleaned.split()]
    token_ids = token_ids[:max_len]  # truncate
    token_ids += [word2key.get("[PAD]")] * (max_len - len(token_ids))  # pad
    return token_ids


# Two demo sentences: with batch_size=2, make_bert_batch returns one IsNext pair
# (sentence 1 -> sentence 2) and one NotNext pair, which requires >= 2 sentences.
demo_sentences = [
    "The quick brown fox jumped over the lazy dog",
    "The dog did not even wake up",
]
ckpt_max_len = checkpoint["max_len"]
pad_id = word2key["[PAD]"]

# tokenize_sentence right-pads to max_len. Strip those [PAD] ids so that
# make_bert_batch sees only real tokens; it adds its own padding and could
# otherwise treat [PAD]s as maskable tokens. A separate name keeps the
# training corpus in `tokenized_ids` untouched.
demo_ids = [
    [tok for tok in tokenize_sentence(sentence, word2key, ckpt_max_len) if tok != pad_id]
    for sentence in demo_sentences
]

batch = make_bert_batch(
    tokenized_sentences=demo_ids,  # list of list[int]
    word2id=word2key,
    max_len=ckpt_max_len,
    max_pred=20,  # must match training
    batch_size=2,
)
input_ids, segment_ids, masked_tokens, masked_pos, is_next_labels = map(
    torch.LongTensor, zip(*batch)
)

input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_pos = masked_pos.to(device)
masked_tokens = masked_tokens.to(device)
is_next_labels = is_next_labels.to(device)

with torch.no_grad():
    logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)

# Decode and display predictions for the first example of the batch.
print("Input sequence (without [PAD]):")
tokens = [key2word.get(idx.item(), "[UNK]") for idx in input_ids[0] if idx.item() != pad_id]
print(" ".join(tokens))
print()

lm_preds = logits_lm.argmax(dim=-1)  # (batch, max_pred)

print("Masked token predictions:")
valid_positions = 0
for slot, pos in enumerate(masked_pos[0]):
    true_id = masked_tokens[0][slot].item()
    if true_id == -100:  # padding slot added by make_bert_batch
        continue
    valid_positions += 1
    true_token = key2word.get(true_id, "[UNK]")
    pred_id = lm_preds[0][slot].item()
    pred_token = key2word.get(pred_id, "[UNK]")
    print(
        f"  Position {pos.item():3d} | True: {true_token:15s} | Pred: {pred_token:15s}"
    )

if valid_positions == 0:
    print("(No masked tokens in this sample)")

# Next Sentence Prediction for the same (first) example.
nsp_pred = logits_nsp.argmax(dim=-1)[0].item()
print("\nNext Sentence Prediction:")
print(f"  Ground truth: {'isNext' if is_next_labels[0].item() else 'notNext'}")
print(f"  Prediction  : {'isNext' if nsp_pred else 'notNext'}")

# Save model

In [None]:
# After training: bundle weights + vocabulary + architecture hyper-parameters so
# the "Load model" cell can rebuild an identical BERT.
# NOTE(review): if the load cell has already run in this kernel, word2key /
# key2word / vocab_size below hold the values read from that checkpoint.
checkpoint = {
    # Model state
    "model_state_dict": best_model.state_dict(),
    "word2id": word2key,
    "id2word": key2word,
    "vocab_size": vocab_size,
    "d_model": d_model,
    "n_layers": n_layers,
    "n_heads": n_heads,
    "max_len": max_len,
    "n_segments": n_segments,
}

# Defaults to the Colab Google-Drive folder. Set BERT_SAVE_DIR to save elsewhere
# (e.g. when running locally without Drive mounted).
save_dir = os.environ.get("BERT_SAVE_DIR", "/content/drive/MyDrive/Lab4")
save_path = os.path.join(save_dir, "bert2_trained.pth")

torch.save(checkpoint, save_path)
print(f"Model saved to {save_path}")