In [None]:
import os
import glob
import random
import torch
from torch.utils.data import Dataset

from transformers import (
    BertConfig,
    BertForMaskedLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

import sentencepiece as spm

In [None]:
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
CHUNK_DIR = "tokenized_chunks"
CONFIG_FILE = "bert_config.json"

# Seed every RNG this notebook draws from (dataset shuffle via `random`, MLM masking via
# torch.bernoulli / torch.randint) so Restart & Run All reproduces the same samples and masks.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Fail fast with an actionable message instead of a bare AssertionError.
assert os.path.isfile(TOKENIZER_MODEL), f"Tokenizer model not found: {TOKENIZER_MODEL}"
assert os.path.isdir(CHUNK_DIR), f"Chunk directory not found: {CHUNK_DIR}"
assert os.path.isfile(CONFIG_FILE), f"BERT config not found: {CONFIG_FILE}"

print("✓ All paths verified")

In [None]:
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)

PAD_ID = sp.pad_id()
MASK_ID = sp.piece_to_id("[MASK]")

# SentencePiece returns the <unk> id from piece_to_id() for pieces that are not in the
# vocab, so a missing "[MASK]" would silently turn every masked position into <unk>.
assert MASK_ID != sp.unk_id(), "'[MASK]' is not in the tokenizer vocab (resolved to <unk>)"
# pad_id() is -1 when the model was trained without a pad piece; the MLM collator compares
# input ids against PAD_ID to skip padding, so with -1 nothing would be excluded.
if PAD_ID < 0:
    print("⚠ Tokenizer has no pad piece (pad_id() == -1); padding will not be excluded by PAD_ID")

print("Vocab size:", sp.get_piece_size())
print("PAD_ID:", PAD_ID)
print("MASK_ID:", MASK_ID)

In [None]:
class BertChunkDataset(Dataset):
    """In-memory map-style dataset over pre-tokenized ``*.pt`` chunk files.

    Every file in ``chunk_dir`` is read with ``torch.load`` and its contents are
    appended (``list.extend``) to one flat in-memory list, so each file is expected
    to hold an iterable of per-sample dicts with ``"input_ids"`` and
    ``"attention_mask"`` entries — NOTE(review): format inferred from
    ``__getitem__``; confirm against the script that wrote the chunks.

    Args:
        chunk_dir: Directory containing the ``*.pt`` chunk files.
        max_files: If not ``None``, only the first ``max_files`` files (sorted by
            filename) are loaded — handy for smoke tests.
        seed: Optional seed for the one-off shuffle. ``None`` (default) uses the
            global ``random`` state as before; an int makes the sample order
            reproducible regardless of global state.

    Raises:
        AssertionError: If no chunk files are selected.
    """

    def __init__(self, chunk_dir, max_files=None, seed=None):
        self.files = sorted(glob.glob(os.path.join(chunk_dir, "*.pt")))
        # `is not None` so an explicit 0 is not silently treated as "load everything".
        if max_files is not None:
            self.files = self.files[:max_files]

        assert len(self.files) > 0, "No chunk files found"

        # Everything is materialized in RAM; memory grows linearly with the files loaded.
        self.data = []
        n_files = len(self.files)
        print(f"Loading {n_files} files...")
        for file_idx, file in enumerate(self.files, start=1):
            # torch.load deserializes via pickle (restricted only when weights_only=True);
            # only point this at chunk files you produced yourself.
            self.data.extend(torch.load(file))
            # Report after the file is actually loaded: every 10 files and at the end.
            if file_idx % 10 == 0 or file_idx == n_files:
                print(f"  Loaded {file_idx}/{n_files} files...")

        print(f"✓ Dataset loaded with {len(self.data)} samples")
        # Shuffle once up front so chunk/file boundaries do not dictate sample order.
        if seed is None:
            random.shuffle(self.data)
        else:
            random.Random(seed).shuffle(self.data)

    def __len__(self):
        """Number of samples loaded across all selected files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``input_ids`` / ``attention_mask`` pair of sample ``idx``."""
        item = self.data[idx]
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"]
        }

# Smoke-test setting: only the first 10 chunk files (in sorted filename order) are loaded.
# Pass max_files=None to load every chunk for a full run.
dataset = BertChunkDataset(
    chunk_dir=CHUNK_DIR,
    max_files=10,
)

In [None]:
class SimpleMLMCollator:
    """BERT-style masked-language-modeling collator for equal-length examples.

    Stacks per-example ``input_ids`` / ``attention_mask`` tensors into a batch and
    selects about ``mlm_probability`` of the eligible positions as prediction
    targets. Of the selected positions, a ``mask_replace_prob`` fraction becomes
    ``mask_token_id``, a ``random_replace_prob`` fraction becomes a uniformly random
    id in ``[0, vocab_size)``, and the rest keep their original id (80 / 10 / 10 by
    default).

    A position is never selected when its id equals ``pad_token_id``, its
    ``attention_mask`` is 0, or its id is listed in ``special_token_ids``.

    Args:
        mask_token_id: Id written into positions chosen for mask replacement.
        pad_token_id: Padding id; positions holding it are never selected.
        vocab_size: Exclusive upper bound for random replacement ids.
        mlm_probability: Per-token probability of being selected as a target.
        mask_replace_prob: Fraction of selected tokens replaced by ``mask_token_id``.
        random_replace_prob: Fraction of selected tokens replaced by a random id.
        special_token_ids: Optional iterable of ids (e.g. [CLS]/[SEP]) never selected.

    Raises:
        ValueError: If the replacement fractions are negative or sum to more than 1.
    """

    def __init__(self, mask_token_id, pad_token_id, vocab_size, mlm_probability=0.15,
                 mask_replace_prob=0.8, random_replace_prob=0.1, special_token_ids=None):
        if (mask_replace_prob < 0 or random_replace_prob < 0
                or mask_replace_prob + random_replace_prob > 1):
            raise ValueError("mask_replace_prob and random_replace_prob must be >= 0 and sum to <= 1")
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size
        self.mlm_probability = mlm_probability
        self.mask_replace_prob = mask_replace_prob
        self.random_replace_prob = random_replace_prob
        self.special_token_ids = tuple(special_token_ids) if special_token_ids is not None else ()

    def __call__(self, examples):
        """Collate ``examples`` into a batch with corrupted ``input_ids`` and MLM ``labels``.

        Returns a dict with ``input_ids`` (corrupted copy), ``attention_mask`` and
        ``labels`` (original id at selected positions, ``-100`` elsewhere so the loss
        ignores them).
        """
        # torch.stack needs every example to have the same length (no dynamic padding here).
        input_ids = torch.stack([e["input_ids"] for e in examples])
        attention_mask = torch.stack([e["attention_mask"] for e in examples])

        # Cross-entropy targets must be int64, whatever dtype the chunks were saved with.
        labels = input_ids.clone().to(torch.long)

        # Positions that must never become targets: padding (by id, and — assuming the usual
        # 1 = real / 0 = pad convention — by attention mask) plus any configured special ids.
        ineligible = (input_ids == self.pad_token_id) | (attention_mask == 0)
        for token_id in self.special_token_ids:
            ineligible |= input_ids == token_id

        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        probability_matrix.masked_fill_(ineligible, 0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        # mask_replace_prob of the targets -> [MASK]
        mask_replace = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
        input_ids[mask_replace] = self.mask_token_id

        # random_replace_prob of the targets -> random id. This draw only applies to targets
        # that were NOT turned into [MASK], so the conditional probability is
        # random_replace_prob / (1 - mask_replace_prob) = 0.1 / 0.2 = 0.5 by default.
        # min() guards against float round-off pushing it just above 1.
        remaining = 1.0 - self.mask_replace_prob
        random_cond_prob = min(1.0, self.random_replace_prob / remaining) if remaining > 0 else 0.0
        random_replace = (
            torch.bernoulli(torch.full(labels.shape, random_cond_prob)).bool()
            & masked_indices
            & ~mask_replace
        )
        # Match input_ids' dtype: index assignment fails on mismatched dtypes (e.g. int32 chunks).
        random_tokens = torch.randint(
            low=0,
            high=self.vocab_size,
            size=labels.shape,
            dtype=input_ids.dtype
        )
        input_ids[random_replace] = random_tokens[random_replace]

        # The remaining targets keep their original id but still contribute to the loss.

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Wire the collator to the tokenizer: PAD_ID positions are never masked, MASK_ID is the
# replacement id, and the SentencePiece vocab size bounds the random-replacement draw.
data_collator = SimpleMLMCollator(
    mlm_probability=0.15,
    vocab_size=sp.get_piece_size(),
    pad_token_id=PAD_ID,
    mask_token_id=MASK_ID,
)
print("✓ Custom MLM collator ready")

In [None]:
# Sanity-check the collator on a handful of samples (guarded for tiny datasets).
n_preview = min(4, len(dataset))
batch = [dataset[i] for i in range(n_preview)]
out = data_collator(batch)

print("input_ids shape:", out["input_ids"].shape)
print("labels shape:", out["labels"].shape)

targets = out["labels"] != -100
# NOTE(review): assumes attention_mask is 1 for real tokens and 0 for padding — confirm
# against the chunking script.
real_tokens = out["attention_mask"].bool()
assert not (targets & ~real_tokens).any(), "padding positions were selected for MLM"

# The 15% MLM rate applies to real tokens, so measure against them, not the padded total.
masked_tokens = int(targets.sum())
total_real = int(real_tokens.sum())
print(f"Masked tokens: {masked_tokens} ({masked_tokens / max(total_real, 1):.1%} of non-padding tokens)")

# How the selected positions were corrupted; expect roughly 80% / 10% / 10% (noisy for a
# small batch; a random draw that hits the original id is counted as "unchanged").
if masked_tokens:
    as_mask = int((targets & (out["input_ids"] == MASK_ID)).sum())
    unchanged = int((targets & (out["input_ids"] == out["labels"])).sum())
    as_random = masked_tokens - as_mask - unchanged
    print(
        f"  [MASK]: {as_mask / masked_tokens:.1%}, "
        f"random: {as_random / masked_tokens:.1%}, "
        f"unchanged: {unchanged / masked_tokens:.1%}"
    )