In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from transformers import BertConfig, BertForMaskedLM, get_linear_schedule_with_warmup
import json
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Device: use CUDA when PyTorch reports an available GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load data: one sequence per line; lines with no whitespace-separated token are dropped
with open("qn_sequences.txt", 'r', encoding="utf-8") as f:
    sentences = [stripped for stripped in (line.strip() for line in f) if stripped.split()]
# sentences = sentences[:50000]  # uncomment to experiment on a subset

# Build vocab: special tokens take the first ids, in the order listed here
special_tokens = ['<unk>', '<pad>', '<sos>', '<eos>', '<mask>']
vocab = {}
for token in special_tokens:
    vocab[token] = len(vocab)

# Corpus-wide token frequencies
word_counter = Counter()
for sent in sentences:
    word_counter.update(sent.split())

# Remaining ids are handed out in descending frequency order
for word, _ in word_counter.most_common():
    vocab.setdefault(word, len(vocab))

# Inverse mapping: id -> token string
reverse_vocab = dict(zip(vocab.values(), vocab.keys()))

In [None]:
# Tokenizer
class SimpleTokenizer:
    """Whitespace tokenizer backed by a fixed token -> id dictionary.

    Args:
        vocab: mapping from token string to integer id. It must contain the
            special tokens that `encode` and `pad` look up ('<unk>', '<pad>').
    """

    def __init__(self, vocab):
        # Special-token strings (looked up in `vocab` on demand)
        self.unk_token, self.pad_token = '<unk>', '<pad>'
        self.sos_token, self.eos_token = '<sos>', '<eos>'
        self.mask_token = '<mask>'
        self.max_len = 32

        self.vocab = vocab
        # id -> token, used by `decode`
        self.reverse_vocab = dict((idx, word) for word, idx in vocab.items())

    def encode(self, text, add_special_tokens=True):
        """Split `text` on whitespace and map each token to its id.

        Tokens missing from the vocab map to the '<unk>' id. When
        `add_special_tokens` is true the result is wrapped in '<sos>' ... '<eos>'.
        """
        words = text.split()
        if add_special_tokens:
            words = [self.sos_token, *words, self.eos_token]
        ids = []
        for word in words:
            ids.append(self.vocab.get(word, self.vocab[self.unk_token]))
        return ids

    def decode(self, ids):
        """Join the token strings for `ids` with single spaces; unknown ids become '<unk>'."""
        return ' '.join(self.reverse_vocab.get(idx, self.unk_token) for idx in ids)

    def pad(self, ids, max_length):
        """Return `ids` truncated to `max_length` and right-padded with the '<pad>' id."""
        pad_id = self.vocab[self.pad_token]
        shortfall = max(0, max_length - len(ids))
        return ids[:max_length] + [pad_id] * shortfall

# Shared tokenizer instance built on the notebook-level `vocab`
tokenizer = SimpleTokenizer(vocab)

In [None]:
# Dataset 
class MLMDataset(Dataset):
    """Masked-language-model dataset with frequency-aware masking.

    Tokens are split into two groups by corpus frequency:

    * "blue"   - tokens whose count is below the mean count per token type
    * "orange" - tokens whose count is at or above that mean

    For each sentence a set of positions is sampled for masking. A sampled
    blue token is replaced by '<mask>' with probability 0.9 (and is always a
    prediction target). A sampled orange token is masked with probability
    `prob_org`; otherwise the mask is moved to a randomly chosen, not yet
    sampled blue token of the same sentence (falling back to the orange token
    itself when the sentence has none). Tokens in neither group are left alone.

    Masking is deterministic per item: item `idx` always uses a private RNG
    derived from `(seed, idx)`, so the global `random` state is never touched
    and different items get independent random streams.

    Args:
        sentences: list of whitespace-tokenised sentences (strings).
        tokenizer: object exposing `vocab`, `encode`, `decode`, `pad` and the
            `pad_token` / `sos_token` / `eos_token` / `mask_token` attributes.
        max_len: length every encoded sentence is padded/truncated to.
        mask_prob: fraction of maskable tokens sampled per sentence (at least one).
        prob_org: probability of masking a sampled orange token in place.
        seed: base seed for the per-item RNG; `None` disables determinism.
        freq_sentences: sentences used to compute the frequency groups.
            Defaults to `sentences`, so the dataset does not depend on any
            notebook-level variable.
    """

    def __init__(self, sentences, tokenizer, max_len=32, mask_prob=0.20, prob_org = 0.2 ,seed = 42, freq_sentences=None): # sentences: list of sentences
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mask_prob = mask_prob
        self.prob_org = prob_org
        self.seed = seed
        self.freq_sentences = sentences if freq_sentences is None else freq_sentences

        # Token-id groups (lists kept for inspection, sets used for O(1) lookups)
        self.blue_group = []
        self.orange_group = []
        self._blue_ids = frozenset()
        self._orange_ids = frozenset()

        self.threshold = 0

        # Statistics filled in by `mask_stat`
        self.mask_blue = Counter()
        self.mask_orange = Counter()
        self.merge_counter = Counter()

        self.word_counter_init()

    def __len__(self):
        return len(self.sentences)

    def _rng_for(self, idx):
        """Return the private RNG used to mask item `idx`."""
        if self.seed is None:
            return random.Random()
        # String seeds are hashed deterministically, independent of PYTHONHASHSEED.
        return random.Random(f"{self.seed}:{idx}")

    def _mask_sentence(self, idx, rng, printed=False):
        """Encode sentence `idx` and apply the frequency-aware masking.

        Returns:
            (input_ids, masked_input_ids, labels, events) where `events` is a
            list of `(group, token_id)` pairs, one per position replaced by
            '<mask>' (`group` is 'blue' or 'orange').
        """
        vocab = self.tokenizer.vocab
        mask_id = vocab[self.tokenizer.mask_token]
        special_ids = {
            vocab[self.tokenizer.pad_token],
            vocab[self.tokenizer.sos_token],
            vocab[self.tokenizer.eos_token],
        }

        text = self.sentences[idx]
        input_ids = self.tokenizer.pad(self.tokenizer.encode(text), self.max_len)
        if printed:
            print('Input ids:', input_ids)

        labels = [-100] * len(input_ids)  # -100 = position ignored by the loss
        masked_input_ids = list(input_ids)

        # Maskable positions: everything except the first/last slot and special tokens
        candidate_indices = [
            i for i in range(1, len(input_ids) - 1)
            if input_ids[i] not in special_ids
        ]
        num_to_mask = max(1, int(self.mask_prob * len(candidate_indices)))
        if printed:
            print('Number of token masked:', num_to_mask)
        mask_indices = rng.sample(candidate_indices, min(len(candidate_indices), num_to_mask))
        if printed:
            print('Random index mask token list:', mask_indices)

        selected = set(mask_indices)
        events = []
        for i in mask_indices:
            original_token = input_ids[i]

            if original_token in self._blue_ids:
                # Rare token: mask 90% of the time, always predict it
                if rng.random() < 0.9:
                    masked_input_ids[i] = mask_id
                    events.append(('blue', original_token))
                labels[i] = original_token
            elif original_token in self._orange_ids:
                if rng.random() < self.prob_org:
                    masked_input_ids[i] = mask_id
                    labels[i] = original_token
                    events.append(('orange', original_token))
                else:
                    # Move the mask to a blue token that was not sampled
                    unselected = [j for j in candidate_indices if j not in selected]
                    if printed:
                        print('Index of token not in random-maskable:', unselected)
                    blue_unselected = [j for j in unselected if input_ids[j] in self._blue_ids]
                    if printed:
                        print('Index of token not in random-maskable and in Blue group:', blue_unselected)

                    if blue_unselected:
                        target, group = rng.choice(blue_unselected), 'blue'
                    else:
                        # No blue alternative in this sentence: mask the orange token itself
                        target, group = i, 'orange'
                    if printed:
                        print('Sampling index:', target)

                    masked_input_ids[target] = mask_id
                    labels[target] = input_ids[target]
                    events.append((group, input_ids[target]))

        return input_ids, masked_input_ids, labels, events

    def __getitem__(self, idx, printed = False):
        """Return the masked example for sentence `idx`.

        Returns:
            dict with tensors `input_ids` (masked ids), `labels` (original id at
            prediction targets, -100 elsewhere) and `attention_mask` (0 on padding).
        """
        input_ids, masked_input_ids, labels, _ = self._mask_sentence(idx, self._rng_for(idx), printed)

        pad_id = self.tokenizer.vocab[self.tokenizer.pad_token]
        attention_mask = [1 if token != pad_id else 0 for token in input_ids]

        # Debug print (only for the first few samples)
        if printed and idx < 2:
            print("Original:", self.tokenizer.decode(input_ids))
            print("Masked  :", self.tokenizer.decode(masked_input_ids))
            print("Labels  :", [self.tokenizer.decode([l]) if l != -100 else '_' for l in labels])

        return {
            "input_ids": torch.tensor(masked_input_ids),
            "labels": torch.tensor(labels),
            "attention_mask": torch.tensor(attention_mask)
        }

    def mask_stat(self, printed = False):
        """Count which tokens get replaced by '<mask>' across the whole dataset.

        Uses the same per-item RNG as `__getitem__`, so (for a non-None seed)
        the counts describe exactly the masks the model is trained on. The
        counters are rebuilt from scratch on every call. Results are stored in
        `mask_blue`, `mask_orange` and their sum `merge_counter`, keyed by
        decoded token string.
        """
        self.mask_blue = Counter()
        self.mask_orange = Counter()

        for idx in range(len(self.sentences)):
            *_, events = self._mask_sentence(idx, self._rng_for(idx), printed)
            for group, token_id in events:
                word = self.tokenizer.decode([token_id])
                target_counter = self.mask_blue if group == 'blue' else self.mask_orange
                target_counter.update([word])

        self.merge_counter = self.mask_blue + self.mask_orange

    def word_counter_init(self):
        """Compute the frequency threshold and the blue/orange token-id groups.

        The threshold is the mean count per distinct token (floored) over
        `freq_sentences`. An empty corpus yields threshold 0 and empty groups.
        """
        word_counter = Counter()
        for sent in self.freq_sentences:
            word_counter.update(sent.split())

        blue, orange = [], []
        if word_counter:
            self.threshold = sum(word_counter.values()) // len(word_counter)
            for word, count in word_counter.items():
                # encode() wraps the word in <sos>/<eos>; keep the word's own id
                token_id = self.tokenizer.encode(word)[1:-1][0]
                (blue if count < self.threshold else orange).append(token_id)
        else:
            self.threshold = 0

        self.blue_group = blue
        self.orange_group = orange
        self._blue_ids = frozenset(blue)
        self._orange_ids = frozenset(orange)

In [None]:
# Split dataset: 20% of the sentences are held out for validation (fixed random_state)

# Special tokens: ['<unk>', '<pad>', '<sos>', '<eos>', '<mask>']

train_sentences, val_sentences = train_test_split(
    sentences,
    test_size=0.2,
    random_state=42,
)

# One masking dataset per split, both sharing the same tokenizer
train_dataset, val_dataset = (
    MLMDataset(split, tokenizer) for split in (train_sentences, val_sentences)
)

# Only the training batches are shuffled
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [None]:
# Model config: BERT encoder sized to the custom vocabulary and tokenizer length
config = BertConfig(
    # Vocabulary / sequence geometry come from the custom tokenizer
    vocab_size=len(vocab),
    pad_token_id=vocab['<pad>'],
    max_position_embeddings=tokenizer.max_len,
    # Encoder size (smaller alternative previously tried: 512 hidden / 6 layers / 8 heads)
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=1024,
)

# Randomly initialised masked-LM model, moved to the selected device
model = BertForMaskedLM(config).to(device)

In [None]:
# Optimizer and Scheduler
# The linear schedule decays the learning rate to zero after `total_steps`
# optimizer steps, so its horizon must cover every epoch the training loop may
# run. It was previously hard-coded to 20 epochs while the loop runs up to 50,
# which left all epochs after the 20th training with a learning rate of 0.
scheduled_epochs = 50  # keep equal to `epochs` in the training loop
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
total_steps = len(train_loader) * scheduled_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.2 * total_steps),  # first 20% of the steps ramp the LR up
    num_training_steps=total_steps
)

In [None]:
# Training function
def train_epoch(model, loader):
    """Run one training pass over `loader` and return the mean per-batch loss.

    Uses the notebook-level `device`, `optimizer` and `scheduler`; the
    scheduler is stepped once per batch.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(loader, desc="Training"):
        # Move the three tensors the model consumes onto the training device
        inputs = {
            key: batch[key].to(device)
            for key in ("input_ids", "labels", "attention_mask")
        }

        loss = model(**inputs).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
    return running_loss / len(loader)

# Evaluation function
def eval_model(model, loader):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns:
        (mean per-batch loss, accuracy over positions whose label is not -100).
        Accuracy is 0 when no position carries a label.
    """
    model.eval()
    loss_sum, n_correct, n_scored = 0, 0, 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = {
                key: batch[key].to(device)
                for key in ("input_ids", "labels", "attention_mask")
            }
            outputs = model(**inputs)
            loss_sum += outputs.loss.item()

            targets = inputs["labels"]
            scored = targets.ne(-100)  # only masked positions count
            hits = outputs.logits.argmax(dim=-1).eq(targets) & scored
            n_correct += hits.sum().item()
            n_scored += scored.sum().item()

    accuracy = n_correct / n_scored if n_scored > 0 else 0
    return loss_sum / len(loader), accuracy

In [None]:
# Training loop with early stopping on validation accuracy
epochs = 50
patience = 5        # epochs without sufficient improvement before stopping
min_delta = 0.002   # minimum accuracy gain that counts as an improvement
best_val_acc = 0
patience_counter = 0

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_loss = train_epoch(model, train_loader)
    val_loss, val_acc = eval_model(model, val_loader)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.4f}")

    if val_acc - best_val_acc > min_delta:
        # New best: checkpoint it and reset the patience budget
        best_val_acc = val_acc
        patience_counter = 0
        model.save_pretrained("bert_custom_mlm_best")
        continue

    patience_counter += 1
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

# Save vocab (the best model checkpoint is written inside the loop)
# model.save_pretrained("./bert_custom_mlm_best")
with open("custom_vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab, f, ensure_ascii=False, indent=4)