In [2]:
import json
import random
import sentencepiece as spm
import torch

# -------- CONFIG --------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "input.txt"
MAX_LEN = 128       # tokens per training window
STRIDE = 64         # step between window starts; MAX_LEN - STRIDE tokens overlap
MLM_PROB = 0.15     # fraction of tokens selected for MLM prediction
PRINT_STEPS = True  # trace every processing step to stdout
# ------------------------

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)

PAD_ID = sp.pad_id()
SEP_ID = sp.eos_id()     # reuse EOS as SEP
MASK_ID = sp.piece_to_id("[MASK]")

# SentencePiece reports -1 for a disabled special piece, and piece_to_id()
# returns the unk id for a piece that is not in the vocabulary. Either case
# would make padding / masking silently wrong, so fail fast here.
if PAD_ID < 0:
    raise ValueError(
        "Tokenizer has no pad piece (pad_id() == -1); retrain it with a pad id."
    )
if SEP_ID < 0:
    raise ValueError(
        "Tokenizer has no EOS piece (eos_id() == -1); it is needed as SEP."
    )
if MASK_ID == sp.unk_id():
    raise ValueError(
        '"[MASK]" is not in the tokenizer vocabulary (piece_to_id returned the '
        "unk id); add it as a user-defined symbol when training the tokenizer."
    )

print("Tokenizer loaded")
print("PAD_ID:", PAD_ID, "SEP_ID:", SEP_ID, "MASK_ID:", MASK_ID)
print("-" * 80)

# Running token buffer and emitted-sample counter used by the windowing loop.
buffer = []
sample_count = 0

def apply_mlm(input_ids, mlm_prob=None, *, special_ids=None, mask_id=None,
              vocab_size=None, generator=None):
    """Apply BERT-style masked-language-model corruption to token ids.

    Each non-special position is selected with probability ``mlm_prob``.
    Of the selected positions, 80% are replaced with ``mask_id``, 10% with a
    uniformly random id in ``[0, vocab_size)`` and 10% are left unchanged.
    Unselected and special positions get label -100 (ignored by the loss).

    The input tensor is NOT modified; a corrupted copy is returned. Any tensor
    shape is accepted (a single sequence or a batch).

    Args:
        input_ids: integer tensor of token ids.
        mlm_prob: selection probability; defaults to the module's MLM_PROB.
        special_ids: ids that are never selected (and always get label -100);
            defaults to (PAD_ID, SEP_ID, MASK_ID).
        mask_id: replacement id for the 80% branch; defaults to MASK_ID.
        vocab_size: upper bound (exclusive) for random replacement ids;
            defaults to ``sp.get_piece_size()``.
        generator: optional ``torch.Generator`` for reproducible masking.

    Returns:
        (masked_input, labels), both tensors with the same shape as input_ids.

    Raises:
        ValueError: if mlm_prob is outside [0, 1].
    """
    # Resolve defaults lazily so the function works with explicit arguments
    # even where the tokenizer globals are not defined.
    if mlm_prob is None:
        mlm_prob = MLM_PROB
    if special_ids is None:
        special_ids = (PAD_ID, SEP_ID, MASK_ID)
    if mask_id is None:
        mask_id = MASK_ID
    if vocab_size is None:
        vocab_size = sp.get_piece_size()
    if not 0.0 <= mlm_prob <= 1.0:
        raise ValueError(f"mlm_prob must be in [0, 1], got {mlm_prob}")

    masked = input_ids.clone()  # never mutate the caller's tensor
    labels = input_ids.clone()
    shape, device = input_ids.shape, input_ids.device

    # Special tokens (padding, separators, the mask itself) are not
    # prediction targets, matching common BERT MLM collators.
    special = torch.zeros(shape, dtype=torch.bool, device=device)
    for sid in special_ids:
        special |= input_ids == sid

    selected = (torch.rand(shape, generator=generator, device=device) < mlm_prob) & ~special
    labels[~selected] = -100

    # One draw decides the 80 / 10 / 10 split for each selected position.
    roll = torch.rand(shape, generator=generator, device=device)
    to_mask = selected & (roll < 0.8)
    to_random = selected & (roll >= 0.8) & (roll < 0.9)

    masked[to_mask] = mask_id
    # NOTE: like the original, random ids are drawn from the full vocabulary
    # and may therefore include special pieces.
    random_ids = torch.randint(vocab_size, shape, generator=generator,
                               dtype=input_ids.dtype, device=device)
    masked[to_random] = random_ids[to_random]

    return masked, labels

# Demo mode: stop after this many samples so the output stays readable.
MAX_DEMO_SAMPLES = 3


def _log(*args, **kwargs):
    """Print only when PRINT_STEPS is enabled."""
    if PRINT_STEPS:
        print(*args, **kwargs)


def _iter_windows(path):
    """Yield overlapping token windows (lists of ids) read from *path*.

    Every non-blank line is encoded, terminated with SEP_ID and appended to a
    running buffer. A full window of MAX_LEN ids is cut whenever the buffer
    holds enough tokens, and the buffer then advances by STRIDE. At EOF, any
    tokens not yet covered by an emitted window are yielded as one final,
    shorter window (the caller pads it).

    Raises:
        ValueError: if STRIDE is not positive (the buffer would never shrink).
    """
    if STRIDE <= 0:
        raise ValueError(f"STRIDE must be positive, got {STRIDE}")

    buffer = []
    # Number of leading buffer tokens already included in an emitted window.
    covered = 0
    with open(path, "r", encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue

            _log(f"\n[STEP 1] Raw line {line_idx}:")
            _log(line)

            tokens = sp.encode(line, out_type=int)
            tokens.append(SEP_ID)

            _log("[STEP 2] Tokenized:")
            _log(tokens)

            buffer.extend(tokens)
            _log("[STEP 3] Buffer length:", len(buffer))

            while len(buffer) >= MAX_LEN:
                _log("\n[STEP 4] Creating window")
                window = buffer[:MAX_LEN]
                buffer = buffer[STRIDE:]
                # The tail of the window just emitted is still in the buffer.
                covered = max(MAX_LEN - STRIDE, 0)
                yield window

    # Flush the remainder instead of silently dropping it.
    if len(buffer) > covered:
        _log("\n[STEP 4] Creating final (padded) window")
        yield buffer


def _emit_sample(window):
    """Pad one token window to MAX_LEN, apply MLM, and log each stage.

    Returns:
        (input_ids, attention_mask, masked_input, labels), each a 1-D tensor
        of length MAX_LEN.
    """
    _log("Window token IDs (first 30):")
    _log(window[:30])

    real_len = len(window)
    input_ids = torch.full((MAX_LEN,), PAD_ID, dtype=torch.long)
    input_ids[:real_len] = torch.tensor(window, dtype=torch.long)

    # Derived from the real length rather than by comparing ids to PAD_ID.
    attention_mask = torch.zeros(MAX_LEN, dtype=torch.long)
    attention_mask[:real_len] = 1

    _log("[STEP 5] Attention mask (sum = real tokens):",
         attention_mask.sum().item())

    _log("[STEP 6] Applying MLM")
    masked_input, labels = apply_mlm(input_ids.clone())

    _log("Original (decoded, first 120 tokens):")
    _log(sp.decode(input_ids[:real_len].tolist()[:120]))

    _log("\nMasked (decoded, first 120 tokens):")
    _log(sp.decode(masked_input[:real_len].tolist()[:120]))

    _log("\nLabels (-100 means ignored, first 30):")
    _log(labels[:30].tolist())

    return input_ids, attention_mask, masked_input, labels


sample_count = 0
windows = _iter_windows(INPUT_TXT)
try:
    for window in windows:
        _emit_sample(window)
        sample_count += 1
        _log("\n✔ Sample created:", sample_count)
        _log("=" * 80)

        if sample_count >= MAX_DEMO_SAMPLES:
            # A plain break (not exit()) keeps a notebook kernel alive.
            _log("\nStopping early (demo mode).")
            break
finally:
    # Close the generator so the input file is released on early stop.
    windows.close()


Tokenizer loaded
PAD_ID: 0 SEP_ID: 3 MASK_ID: 1
--------------------------------------------------------------------------------

[STEP 1] Raw line 0:
පිට පිට දෙවැනි දිනටත් දෛනික ආසාදිතයින් 600 ඉක්මවයි
[STEP 2] Tokenized:
[1279, 1279, 523, 6885, 2565, 836, 4907, 7585, 3]
[STEP 3] Buffer length: 9

[STEP 1] Raw line 1:
කොවිඩ් -19 වෛරසය ආසාදනය වූ තවත් 135 දෙනෙකු අද දිනයේ හඳුනාගත් බව යුද හමුදාපති, ජෙනරාල් ශවේන්ද්‍ර සිල්වා මහතා පවසයි.
[STEP 2] Tokenized:
[353, 3865, 902, 1244, 29, 106, 12281, 213, 31, 549, 2998, 6, 528, 2611, 5, 2173, 4861, 453, 30, 126, 4, 3]
[STEP 3] Buffer length: 31

[STEP 1] Raw line 2:
ඒ අනුව දෛනික ආසාදිතයින් ගණන 617කි.
[STEP 2] Tokenized:
[12, 49, 2565, 836, 1140, 361, 4825, 157, 4, 3]
[STEP 3] Buffer length: 41

[STEP 1] Raw line 3:
කොවිඩ් ආසාදිත බවට මේ වනවිට මෙරට දී හඳුනාගෙන ඇති සමස්ත ආසාදිතයින් ගණන 544,630කි.
[STEP 2] Tokenized:
[353, 1696, 68, 9, 235, 249, 41, 1358, 8, 582, 836, 1140, 3777, 6797, 2133, 3648, 157, 4, 3]
[STEP 3] Buffer length: 60

[STEP 1] Raw li

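This code shows how to build a BERT-style masked-language-modeling (MLM) dataset with a SentencePiece unigram tokenizer.
Because the corpus consists of relatively short sentences, the maximum sequence length is set to 128 tokens and the stride to 64.
A sliding-window approach concatenates the tokenized lines and cuts overlapping segments, so consecutive windows share 64 tokens (MAX_LEN − STRIDE).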