In [None]:
import os
import glob
import random
import torch
from torch.utils.data import Dataset

from transformers import (
    BertConfig,
    BertForMaskedLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

import sentencepiece as spm

In [None]:
# Input artefacts for this notebook. The paths are relative, so they are
# resolved against the kernel's current working directory.
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"  # loaded by SentencePiece below
CHUNK_DIR = "tokenized_chunks"  # directory handed to BertChunkDataset
CONFIG_FILE = "bert_config.json"

# Check every path up front and report all missing ones at once (as absolute
# paths, together with the working directory) instead of stopping at the
# first failure with a message-less AssertionError.
_missing_paths = [
    os.path.abspath(path)
    for path in (TOKENIZER_MODEL, CHUNK_DIR, CONFIG_FILE)
    if not os.path.exists(path)
]
assert not _missing_paths, (
    f"Required path(s) not found (cwd={os.getcwd()}): {_missing_paths}"
)

print("✓ All paths verified")

In [None]:
# Load the SentencePiece model and look up the special-token ids that the
# rest of the notebook refers to.
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)

PAD_ID = sp.pad_id()
MASK_ID = sp.piece_to_id("[MASK]")

print("Vocab size:", sp.get_piece_size())
print("PAD_ID:", PAD_ID)
print("MASK_ID:", MASK_ID)

# SentencePiece does not fail on either lookup above: a disabled padding
# token is reported as a negative id, and a piece that is not in the
# vocabulary is mapped to the unknown-token id. Surface both cases here so
# they cannot silently corrupt padding or masking later on.
if PAD_ID < 0:
    print(
        "WARNING: the tokenizer has no padding token (pad_id is negative); "
        "PAD_ID cannot be used as a token id."
    )
if MASK_ID == sp.unk_id():
    print(
        "WARNING: '[MASK]' is not in the tokenizer vocabulary; "
        "MASK_ID fell back to the unknown-token id."
    )

In [None]:
class BertChunkDataset(Dataset):
    """In-memory dataset assembled from pre-tokenized ``*.pt`` chunk files.

    Every selected chunk file is read with ``torch.load`` and its contents are
    appended to one flat list, which is shuffled once at construction time.
    Samples are indexed with the keys ``"input_ids"`` and ``"attention_mask"``
    (presumably each file stores a list of such dicts -- TODO confirm against
    the script that wrote the chunks).

    Args:
        chunk_dir: Directory containing the ``*.pt`` chunk files.
        max_files: Keep only the first ``max_files`` files in sorted name
            order. ``None`` (or ``0``) loads every file.
        seed: Optional seed for the one-off shuffle. When given, the sample
            order is reproducible across kernel restarts; when ``None`` the
            module-level ``random`` state is used, as before.

    Raises:
        AssertionError: If no ``*.pt`` file is found in ``chunk_dir``.
    """

    def __init__(self, chunk_dir, max_files=None, seed=None):
        # Escape the directory part so characters such as "[" or "*" in its
        # name are matched literally instead of being read as wildcards.
        pattern = f"{glob.escape(str(chunk_dir))}/*.pt"
        self.files = sorted(glob.glob(pattern))
        if max_files:
            self.files = self.files[:max_files]

        assert len(self.files) > 0, "No chunk files found"

        # Load all data into memory
        self.data = []
        print(f"Loading {len(self.files)} files...")
        for file_idx, file in enumerate(self.files):
            if file_idx % 10 == 0:
                print(f"  Loaded {file_idx}/{len(self.files)} files...")
            # NOTE(security): torch.load unpickles the file, so only load
            # chunk files from a trusted source.
            # map_location="cpu" keeps any stored tensors on the CPU even if
            # they were saved from a GPU, so loading also works on hosts
            # without that device.
            file_data = torch.load(file, map_location="cpu")
            self.data.extend(file_data)

        print(f"✓ Dataset loaded with {len(self.data)} samples")

        # Shuffle once. A private generator is used for a seeded shuffle so
        # the global random state is left untouched.
        if seed is None:
            random.shuffle(self.data)
        else:
            random.Random(seed).shuffle(self.data)

    def __len__(self):
        """Return the total number of samples loaded from all chunk files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``input_ids`` and ``attention_mask`` of sample ``idx``.

        Any other keys stored with the sample are dropped.
        """
        item = self.data[idx]
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"]
        }

# Number of chunk files loaded for this trial run. Kept small so the dataset
# builds quickly while testing; set to None to load every chunk file.
MAX_CHUNK_FILES = 10

dataset = BertChunkDataset(CHUNK_DIR, max_files=MAX_CHUNK_FILES)