In [1]:
import json
import random
import sentencepiece as spm
import torch

# -------- CONFIG --------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "input.txt"          
MAX_LEN = 128
STRIDE = 64
MLM_PROB = 0.15
PRINT_STEPS = True
# ------------------------

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)

PAD_ID = sp.pad_id()
SEP_ID = sp.eos_id()     # reuse EOS as SEP
MASK_ID = sp.piece_to_id("[MASK]")

print("Tokenizer loaded")
print("PAD_ID:", PAD_ID, "SEP_ID:", SEP_ID, "MASK_ID:", MASK_ID)
print("-" * 80)

buffer = []
sample_count = 0

def apply_mlm(input_ids):
    labels = input_ids.clone()

    for i in range(len(input_ids)):
        if input_ids[i] == PAD_ID:
            labels[i] = -100
            continue

        if random.random() < MLM_PROB:
            prob = random.random()
            if prob < 0.8:
                input_ids[i] = MASK_ID
            elif prob < 0.9:
                input_ids[i] = random.randint(0, sp.get_piece_size() - 1)
            # else: keep token
        else:
            labels[i] = -100

    return input_ids, labels

with open(INPUT_TXT, "r", encoding="utf-8") as f:
    for line_idx, line in enumerate(f):
        line = line.strip()
        if not line:
            continue

        print(f"\n[STEP 1] Raw line {line_idx}:")
        print(line)

        tokens = sp.encode(line, out_type=int)
        tokens.append(SEP_ID)

        print("[STEP 2] Tokenized:")
        print(tokens)

        buffer.extend(tokens)

        print("[STEP 3] Buffer length:", len(buffer))

        while len(buffer) >= MAX_LEN:
            print("\n[STEP 4] Creating window")

            window = buffer[:MAX_LEN]
            buffer = buffer[STRIDE:]

            print("Window token IDs (first 30):")
            print(window[:30])

            input_ids = torch.tensor(window)

            pad_len = MAX_LEN - len(input_ids)
            if pad_len > 0:
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), PAD_ID)
                ])

            attention_mask = (input_ids != PAD_ID).long()

            print("[STEP 5] Attention mask (sum = real tokens):",
                  attention_mask.sum().item())

            print("[STEP 6] Applying MLM")
            masked_input, labels = apply_mlm(input_ids.clone())

            print("Original (decoded, first 120 tokens):")
            print(sp.decode(input_ids.tolist()[:120]))

            print("\nMasked (decoded, first 120 tokens):")
            print(sp.decode(masked_input.tolist()[:120]))

            print("\nLabels (-100 means ignored, first 30):")
            print(labels[:30].tolist())

            sample_count += 1
            print("\n✔ Sample created:", sample_count)
            print("=" * 80)

            # Stop early so output is readable
            if sample_count >= 3:
                print("\nStopping early (demo mode).")
                exit()


Tokenizer loaded
PAD_ID: 0 SEP_ID: 3 MASK_ID: 4
--------------------------------------------------------------------------------

[STEP 1] Raw line 0:
පිට පිට දෙවැනි දිනටත් දෛනික ආසාදිතයින් 600 ඉක්මවයි
[STEP 2] Tokenized:
[1280, 1280, 523, 6892, 2569, 837, 4908, 7595, 3]
[STEP 3] Buffer length: 9

[STEP 1] Raw line 1:
කොවිඩ් -19 වෛරසය ආසාදනය වූ තවත් 135 දෙනෙකු අද දිනයේ හඳුනාගත් බව යුද හමුදාපති, ජෙනරාල් ශවේන්ද්‍ර සිල්වා මහතා පවසයි.
[STEP 2] Tokenized:
[354, 3862, 903, 1245, 30, 106, 12287, 214, 32, 550, 3005, 7, 526, 2621, 6, 2173, 4865, 453, 31, 127, 5, 3]
[STEP 3] Buffer length: 31

[STEP 1] Raw line 2:
ඒ අනුව දෛනික ආසාදිතයින් ගණන 617කි.
[STEP 2] Tokenized:
[13, 50, 2569, 837, 1142, 362, 4788, 158, 5, 3]
[STEP 3] Buffer length: 41

[STEP 1] Raw line 3:
කොවිඩ් ආසාදිත බවට මේ වනවිට මෙරට දී හඳුනාගෙන ඇති සමස්ත ආසාදිතයින් ගණන 544,630කි.
[STEP 2] Tokenized:
[354, 1696, 69, 10, 236, 250, 42, 1357, 9, 583, 837, 1142, 3772, 6827, 2139, 3655, 158, 5, 3]
[STEP 3] Buffer length: 60

[STEP 1] Raw l

this code shows how to create a BERT dataset with masked language modeling using a unigram tokenizer.
since this dataset has relatively short sequences, we set max length to 128 and stride to 64.
and we use a sliding window approach to create overlapping segments.

# NOTE
in the above code instead of using a `<mask>` token i have used `<unk>` token for masking. but it should **NOT** be used in actual implementation.

In [2]:
import sentencepiece as spm
import torch
from tqdm import tqdm
import os

# ---------------- CONFIG ----------------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "combined.txt"
OUTPUT_PT = "bert_dataset_256.pt"

MAX_LEN = 256
STRIDE = 128
BATCH_SIZE = 1000000  # Write every 10k sequences to disk
# ---------------------------------------

assert os.path.exists(TOKENIZER_MODEL), "Tokenizer model not found"
assert os.path.exists(INPUT_TXT), "Input text file not found"

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)

PAD_ID = sp.pad_id()
SEP_ID = sp.eos_id()

if PAD_ID == -1 or SEP_ID == -1:
    raise ValueError("Tokenizer must have PAD and EOS tokens")

print("Tokenizer loaded")
print("Vocab size:", sp.get_piece_size())
print("PAD_ID:", PAD_ID, "SEP_ID:", SEP_ID)

buffer = []
batch = []
total_tokens = 0
total_windows = 0

def save_batch(batch, filepath, mode='ab'):
    """Save batch to disk incrementally"""
    if not batch:
        return
    
    # Convert to tensors only when saving
    tensors = [
        {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(item["attention_mask"], dtype=torch.long)
        }
        for item in batch
    ]
    
    # Append to file
    if mode == 'wb' or not os.path.exists(filepath):
        torch.save(tensors, filepath)
    else:
        # Load existing, append, save
        existing = torch.load(filepath)
        existing.extend(tensors)
        torch.save(existing, filepath)

# Get total lines for progress bar
print("\nCounting lines...")
with open(INPUT_TXT, "r", encoding="utf-8") as f:
    total_lines = sum(1 for _ in f)

print(f"Total lines: {total_lines:,}")
print("Tokenizing and saving in batches...")

with open(INPUT_TXT, "r", encoding="utf-8") as f:
    for line in tqdm(f, total=total_lines, desc="Processing", 
                     unit="lines", smoothing=0.05):
        line = line.strip()
        if not line:
            continue

        tokens = sp.encode(line, out_type=int)
        tokens.append(SEP_ID)

        buffer.extend(tokens)
        total_tokens += len(tokens)

        while len(buffer) >= MAX_LEN:
            window = buffer[:MAX_LEN]
            buffer = buffer[STRIDE:]

            # Store as lists (much lighter than tensors)
            batch.append({
                "input_ids": window,
                "attention_mask": [1] * MAX_LEN
            })
            total_windows += 1

            # Save batch periodically
            if len(batch) >= BATCH_SIZE:
                save_batch(batch, OUTPUT_PT, mode='ab' if total_windows > BATCH_SIZE else 'wb')
                batch = []  # Clear memory

# Handle final remainder
if buffer:
    pad_len = MAX_LEN - len(buffer)
    input_ids = buffer + [PAD_ID] * pad_len
    attention_mask = [1] * len(buffer) + [0] * pad_len

    batch.append({
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })
    total_windows += 1

# Save final batch
if batch:
    save_batch(batch, OUTPUT_PT, mode='ab' if total_windows > len(batch) else 'wb')

print("\n=== DATASET SUMMARY ===")
print(f"Total tokens seen  : {total_tokens:,}")
print(f"Total windows      : {total_windows:,}")
print(f"Window length      : {MAX_LEN}")
print(f"Stride             : {STRIDE}")
print(f"\nSaved to: {OUTPUT_PT}")
print("Done.")

Tokenizer loaded
Vocab size: 32000
PAD_ID: 0 SEP_ID: 3

Counting lines...
Total lines: 26,557,763
Tokenizing and saving in batches...


Processing:  24%|██▎       | 6254889/26557763 [02:53<09:22, 36112.76lines/s]


KeyboardInterrupt: 

In [None]:
import sentencepiece as spm
import torch
from tqdm import tqdm
import os
import multiprocessing as mp
from functools import partial

# ---------------- CONFIG ----------------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "combined.txt"
OUTPUT_DIR = "bert_shards"
FINAL_OUTPUT = "bert_dataset_256.pt"

MAX_LEN = 256
STRIDE = 128
NUM_SHARDS = mp.cpu_count()  # Use all CPU cores
MERGE_AFTER = True  # Set False to keep shards separate
# ---------------------------------------

os.makedirs(OUTPUT_DIR, exist_ok=True)

def process_shard(shard_id, lines, tokenizer_model, output_dir, max_len, stride):
    """Process a single shard of data"""
    # Load tokenizer in worker process
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_model)
    
    PAD_ID = sp.pad_id()
    SEP_ID = sp.eos_id()
    
    buffer = []
    dataset = []
    total_tokens = 0
    
    pbar = tqdm(lines, desc=f"Shard {shard_id}", position=shard_id, leave=True)
    
    for line in pbar:
        line = line.strip()
        if not line:
            continue
        
        tokens = sp.encode(line, out_type=int)
        tokens.append(SEP_ID)
        
        buffer.extend(tokens)
        total_tokens += len(tokens)
        
        while len(buffer) >= max_len:
            window = buffer[:max_len]
            buffer = buffer[stride:]
            
            dataset.append({
                "input_ids": torch.tensor(window, dtype=torch.long),
                "attention_mask": torch.ones(max_len, dtype=torch.long)
            })
            
            # Update progress with stats
            pbar.set_postfix({
                'tokens': f"{total_tokens:,}",
                'windows': f"{len(dataset):,}"
            })
    
    # Handle remainder
    if buffer:
        pad_len = max_len - len(buffer)
        input_ids = buffer + [PAD_ID] * pad_len
        attention_mask = [1] * len(buffer) + [0] * pad_len
        
        dataset.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long)
        })
    
    # Save shard
    output_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
    torch.save(dataset, output_path)
    
    return {
        'shard_id': shard_id,
        'total_tokens': total_tokens,
        'total_windows': len(dataset),
        'output_path': output_path
    }

def split_file_into_shards(filepath, num_shards):
    """Split file into roughly equal shards by line count"""
    print("\nCounting lines...")
    with open(filepath, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    
    print(f"Total lines: {total_lines:,}")
    print(f"Lines per shard: ~{total_lines // num_shards:,}")
    
    lines_per_shard = total_lines // num_shards
    shards = [[] for _ in range(num_shards)]
    
    print("\nReading and distributing lines...")
    with open(filepath, "r", encoding="utf-8") as f:
        for i, line in enumerate(tqdm(f, total=total_lines, desc="Loading")):
            shard_idx = min(i // lines_per_shard, num_shards - 1)
            shards[shard_idx].append(line)
    
    return shards

def merge_shards(output_dir, num_shards, final_output):
    """Merge all shards into a single file"""
    print("\nMerging shards...")
    merged_dataset = []
    
    for shard_id in tqdm(range(num_shards), desc="Merging"):
        shard_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
        shard_data = torch.load(shard_path)
        merged_dataset.extend(shard_data)
    
    print(f"Saving merged dataset ({len(merged_dataset):,} windows)...")
    torch.save(merged_dataset, final_output)
    
    # Optional: cleanup shards
    print("\nCleanup shards? (y/n): ", end='')
    if input().lower() == 'y':
        for shard_id in range(num_shards):
            os.remove(os.path.join(output_dir, f"shard_{shard_id}.pt"))
        print("Shards deleted.")

if __name__ == "__main__":
    assert os.path.exists(TOKENIZER_MODEL), "Tokenizer model not found"
    assert os.path.exists(INPUT_TXT), "Input text file not found"
    
    # Detect CPU cores
    num_cores = mp.cpu_count()
    print(f"Detected {num_cores} CPU cores")
    print(f"Will create {NUM_SHARDS} shards")
    
    # Load tokenizer to check vocab
    sp = spm.SentencePieceProcessor()
    sp.load(TOKENIZER_MODEL)
    print("\nTokenizer loaded")
    print("Vocab size:", sp.get_piece_size())
    print("PAD_ID:", sp.pad_id(), "SEP_ID:", sp.eos_id())
    
    # Split file into shards
    shards = split_file_into_shards(INPUT_TXT, NUM_SHARDS)
    
    # Process shards in parallel
    print(f"\nProcessing {NUM_SHARDS} shards in parallel...")
    print("=" * 60)
    
    with mp.Pool(processes=NUM_SHARDS) as pool:
        process_func = partial(
            process_shard,
            tokenizer_model=TOKENIZER_MODEL,
            output_dir=OUTPUT_DIR,
            max_len=MAX_LEN,
            stride=STRIDE
        )
        
        results = pool.starmap(
            process_func,
            [(i, shards[i]) for i in range(NUM_SHARDS)]
        )
    
    # Print summary
    print("\n" + "=" * 60)
    print("=== PROCESSING COMPLETE ===")
    total_tokens = sum(r['total_tokens'] for r in results)
    total_windows = sum(r['total_windows'] for r in results)
    
    print(f"\nTotal tokens processed: {total_tokens:,}")
    print(f"Total windows created: {total_windows:,}")
    print(f"\nPer-shard breakdown:")
    for r in results:
        print(f"  Shard {r['shard_id']}: {r['total_tokens']:,} tokens, "
              f"{r['total_windows']:,} windows")
    
    # Merge if requested
    if MERGE_AFTER:
        merge_shards(OUTPUT_DIR, NUM_SHARDS, FINAL_OUTPUT)
        print(f"\n✓ Final dataset saved to: {FINAL_OUTPUT}")
    else:
        print(f"\n✓ Shards saved in: {OUTPUT_DIR}/")
        print("  Use these directly in your DataLoader for distributed training")
    
    print("\nDone!")

Detected 32 CPU cores
Will create 32 shards

Tokenizer loaded
Vocab size: 32000
PAD_ID: 0 SEP_ID: 3

Counting lines...
Total lines: 26,557,763
Lines per shard: ~829,930

Reading and distributing lines...


Loading: 100%|██████████| 26557763/26557763 [00:24<00:00, 1063321.13it/s]


Processing 32 shards in parallel...



Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=42,376, windows=330]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=42,519, windows=331]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=42,631, windows=332]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=42,764, windows=333]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=42,889, windows=334]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=43,011, windows=335]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=43,159, windows=336]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=43,277, windows=337]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=43,400, windows=338]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, tokens=43,570, windows=339]
Shard 0:   0%|          | 1732/829930 [00:01<01:34, 8789.75it/s, toke

In [None]:
import sentencepiece as spm
import torch
from tqdm import tqdm
import os
import multiprocessing as mp
from functools import partial

# ---------------- CONFIG ----------------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "combined.txt"
OUTPUT_DIR = "bert_shards"
FINAL_OUTPUT = "bert_dataset_256.pt"

MAX_LEN = 256
STRIDE = 128
NUM_SHARDS = 4  # Split into 4 parts
MERGE_AFTER = True  # Set False to keep shards separate
# ---------------------------------------

os.makedirs(OUTPUT_DIR, exist_ok=True)

def process_shard(shard_id, lines, tokenizer_model, output_dir, max_len, stride):
    """Process a single shard of data"""
    # Load tokenizer in worker process
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_model)
    
    PAD_ID = sp.pad_id()
    SEP_ID = sp.eos_id()
    
    buffer = []
    dataset = []
    total_tokens = 0
    
    pbar = tqdm(lines, desc=f"Shard {shard_id}", position=shard_id, leave=True)
    
    for line in pbar:
        line = line.strip()
        if not line:
            continue
        
        tokens = sp.encode(line, out_type=int)
        tokens.append(SEP_ID)
        
        buffer.extend(tokens)
        total_tokens += len(tokens)
        
        while len(buffer) >= max_len:
            window = buffer[:max_len]
            buffer = buffer[stride:]
            
            dataset.append({
                "input_ids": torch.tensor(window, dtype=torch.long),
                "attention_mask": torch.ones(max_len, dtype=torch.long)
            })
            
            # Update progress with stats
            pbar.set_postfix({
                'tokens': f"{total_tokens:,}",
                'windows': f"{len(dataset):,}"
            })
    
    # Handle remainder
    if buffer:
        pad_len = max_len - len(buffer)
        input_ids = buffer + [PAD_ID] * pad_len
        attention_mask = [1] * len(buffer) + [0] * pad_len
        
        dataset.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long)
        })
    
    # Save shard
    output_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
    torch.save(dataset, output_path)
    
    return {
        'shard_id': shard_id,
        'total_tokens': total_tokens,
        'total_windows': len(dataset),
        'output_path': output_path
    }

def split_file_into_shards(filepath, num_shards):
    """Split file into roughly equal shards by line count"""
    print("\nCounting lines...")
    with open(filepath, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    
    print(f"Total lines: {total_lines:,}")
    print(f"Lines per shard: ~{total_lines // num_shards:,}")
    
    lines_per_shard = total_lines // num_shards
    shards = [[] for _ in range(num_shards)]
    
    print("\nReading and distributing lines...")
    with open(filepath, "r", encoding="utf-8") as f:
        for i, line in enumerate(tqdm(f, total=total_lines, desc="Loading")):
            shard_idx = min(i // lines_per_shard, num_shards - 1)
            shards[shard_idx].append(line)
    
    return shards

def merge_shards(output_dir, num_shards, final_output):
    """Merge all shards into a single file"""
    print("\nMerging shards...")
    merged_dataset = []
    
    for shard_id in tqdm(range(num_shards), desc="Merging"):
        shard_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
        shard_data = torch.load(shard_path)
        merged_dataset.extend(shard_data)
    
    print(f"Saving merged dataset ({len(merged_dataset):,} windows)...")
    torch.save(merged_dataset, final_output)
    
    # Optional: cleanup shards
    print("\nCleanup shards? (y/n): ", end='')
    if input().lower() == 'y':
        for shard_id in range(num_shards):
            os.remove(os.path.join(output_dir, f"shard_{shard_id}.pt"))
        print("Shards deleted.")

if __name__ == "__main__":
    assert os.path.exists(TOKENIZER_MODEL), "Tokenizer model not found"
    assert os.path.exists(INPUT_TXT), "Input text file not found"
    
    # Load tokenizer to check vocab
    sp = spm.SentencePieceProcessor()
    sp.load(TOKENIZER_MODEL)
    print("Tokenizer loaded")
    print("Vocab size:", sp.get_piece_size())
    print("PAD_ID:", sp.pad_id(), "SEP_ID:", sp.eos_id())
    
    # Split file into shards
    shards = split_file_into_shards(INPUT_TXT, NUM_SHARDS)
    
    # Process shards in parallel
    print(f"\nProcessing {NUM_SHARDS} shards in parallel...")
    print("=" * 60)
    
    with mp.Pool(processes=NUM_SHARDS) as pool:
        process_func = partial(
            process_shard,
            tokenizer_model=TOKENIZER_MODEL,
            output_dir=OUTPUT_DIR,
            max_len=MAX_LEN,
            stride=STRIDE
        )
        
        results = pool.starmap(
            process_func,
            [(i, shards[i]) for i in range(NUM_SHARDS)]
        )
    
    # Print summary
    print("\n" + "=" * 60)
    print("=== PROCESSING COMPLETE ===")
    total_tokens = sum(r['total_tokens'] for r in results)
    total_windows = sum(r['total_windows'] for r in results)
    
    print(f"\nTotal tokens processed: {total_tokens:,}")
    print(f"Total windows created: {total_windows:,}")
    print(f"\nPer-shard breakdown:")
    for r in results:
        print(f"  Shard {r['shard_id']}: {r['total_tokens']:,} tokens, "
              f"{r['total_windows']:,} windows")
    
    # Merge if requested
    if MERGE_AFTER:
        merge_shards(OUTPUT_DIR, NUM_SHARDS, FINAL_OUTPUT)
        print(f"\n✓ Final dataset saved to: {FINAL_OUTPUT}")
    else:
        print(f"\n✓ Shards saved in: {OUTPUT_DIR}/")
        print("  Use these directly in your DataLoader for distributed training")
    
    print("\nDone!")

Tokenizer loaded
Vocab size: 32000
PAD_ID: 0 SEP_ID: 3

Counting lines...
Total lines: 26,557,763
Lines per shard: ~6,639,440

Reading and distributing lines...


Loading: 100%|██████████| 26557763/26557763 [00:24<00:00, 1063681.42it/s]



Processing 4 shards in parallel...


Shard 0:   0%|          | 23583/6639440 [00:03<11:41, 9432.17it/s, tokens=476,950, windows=3,725]

In [None]:
import sentencepiece as spm
import torch
from tqdm import tqdm
import os
import multiprocessing as mp
from functools import partial

# ---------------- CONFIG ----------------
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
INPUT_TXT = "combined.txt"
OUTPUT_DIR = "bert_shards"
FINAL_OUTPUT = "bert_dataset_256.pt"

MAX_LEN = 256
STRIDE = 128
NUM_SHARDS = mp.cpu_count()  # Use all CPU cores
MERGE_AFTER = True  # Set False to keep shards separate
# ---------------------------------------

os.makedirs(OUTPUT_DIR, exist_ok=True)

def process_shard(shard_id, start_line, end_line, filepath, tokenizer_model, output_dir, max_len, stride):
    """Process a single shard of data by reading only assigned lines"""
    # Load tokenizer in worker process
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_model)
    
    PAD_ID = sp.pad_id()
    SEP_ID = sp.eos_id()
    
    buffer = []
    batch = []
    total_tokens = 0
    total_windows = 0
    BATCH_SIZE = 10000  # Save every 10k windows to avoid memory issues
    
    output_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
    
    # Read only the assigned line range
    with open(filepath, "r", encoding="utf-8") as f:
        pbar = tqdm(total=end_line - start_line, desc=f"Shard {shard_id}", 
                   position=shard_id, leave=True)
        
        for line_num, line in enumerate(f):
            # Skip until we reach our start
            if line_num < start_line:
                continue
            if line_num >= end_line:
                break
            
            line = line.strip()
            if not line:
                pbar.update(1)
                continue
            
            tokens = sp.encode(line, out_type=int)
            tokens.append(SEP_ID)
            
            buffer.extend(tokens)
            total_tokens += len(tokens)
            
            while len(buffer) >= max_len:
                window = buffer[:max_len]
                buffer = buffer[stride:]
                
                # Store as list first (lighter than tensor)
                batch.append({
                    "input_ids": window,
                    "attention_mask": [1] * max_len
                })
                total_windows += 1
                
                # Periodically save batch to disk
                if len(batch) >= BATCH_SIZE:
                    save_batch_to_disk(batch, output_path, 
                                      mode='ab' if total_windows > BATCH_SIZE else 'wb')
                    batch = []
            
            pbar.update(1)
            pbar.set_postfix({
                'tokens': f"{total_tokens:,}",
                'windows': f"{total_windows:,}"
            })
        
        pbar.close()
    
    # Handle remainder
    if buffer:
        pad_len = max_len - len(buffer)
        input_ids = buffer + [PAD_ID] * pad_len
        attention_mask = [1] * len(buffer) + [0] * pad_len
        
        batch.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask
        })
        total_windows += 1
    
    # Save final batch
    if batch:
        save_batch_to_disk(batch, output_path, 
                          mode='ab' if total_windows > len(batch) else 'wb')
    
    return {
        'shard_id': shard_id,
        'total_tokens': total_tokens,
        'total_windows': total_windows,
        'output_path': output_path
    }

def save_batch_to_disk(batch, filepath, mode='ab'):
    """Save batch to disk, converting lists to tensors"""
    if not batch:
        return
    
    tensors = [
        {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(item["attention_mask"], dtype=torch.long)
        }
        for item in batch
    ]
    
    if mode == 'wb' or not os.path.exists(filepath):
        torch.save(tensors, filepath)
    else:
        existing = torch.load(filepath)
        existing.extend(tensors)
        torch.save(existing, filepath)

def get_line_ranges(filepath, num_shards):
    """Calculate line ranges for each shard without loading file into memory"""
    print("\nCounting lines...")
    with open(filepath, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in tqdm(f, desc="Counting"))
    
    print(f"Total lines: {total_lines:,}")
    lines_per_shard = total_lines // num_shards
    
    ranges = []
    for i in range(num_shards):
        start = i * lines_per_shard
        end = (i + 1) * lines_per_shard if i < num_shards - 1 else total_lines
        ranges.append((start, end))
        print(f"  Shard {i}: lines {start:,} to {end:,} ({end-start:,} lines)")
    
    return ranges, total_lines

def merge_shards(output_dir, num_shards, final_output):
    """Merge all shards into a single file"""
    print("\nMerging shards...")
    merged_dataset = []
    
    for shard_id in tqdm(range(num_shards), desc="Merging"):
        shard_path = os.path.join(output_dir, f"shard_{shard_id}.pt")
        shard_data = torch.load(shard_path)
        merged_dataset.extend(shard_data)
    
    print(f"Saving merged dataset ({len(merged_dataset):,} windows)...")
    torch.save(merged_dataset, final_output)
    
    # Optional: cleanup shards
    print("\nCleanup shards? (y/n): ", end='')
    if input().lower() == 'y':
        for shard_id in range(num_shards):
            os.remove(os.path.join(output_dir, f"shard_{shard_id}.pt"))
        print("Shards deleted.")

if __name__ == "__main__":
    assert os.path.exists(TOKENIZER_MODEL), "Tokenizer model not found"
    assert os.path.exists(INPUT_TXT), "Input text file not found"
    
    # Detect CPU cores
    num_cores = mp.cpu_count()
    print(f"Detected {num_cores} CPU cores")
    print(f"Will create {NUM_SHARDS} shards")
    
    # Load tokenizer to check vocab
    sp = spm.SentencePieceProcessor()
    sp.load(TOKENIZER_MODEL)
    print("\nTokenizer loaded")
    print("Vocab size:", sp.get_piece_size())
    print("PAD_ID:", sp.pad_id(), "SEP_ID:", sp.eos_id())
    
    # Calculate line ranges for each shard
    ranges, total_lines = get_line_ranges(INPUT_TXT, NUM_SHARDS)
    
    # Process shards in parallel
    print(f"\nProcessing {NUM_SHARDS} shards in parallel...")
    print("=" * 60)
    
    with mp.Pool(processes=NUM_SHARDS) as pool:
        process_func = partial(
            process_shard,
            filepath=INPUT_TXT,
            tokenizer_model=TOKENIZER_MODEL,
            output_dir=OUTPUT_DIR,
            max_len=MAX_LEN,
            stride=STRIDE
        )
        
        results = pool.starmap(
            process_func,
            [(i, start, end) for i, (start, end) in enumerate(ranges)]
        )
    
    # Print summary
    print("\n" + "=" * 60)
    print("=== PROCESSING COMPLETE ===")
    total_tokens = sum(r['total_tokens'] for r in results)
    total_windows = sum(r['total_windows'] for r in results)
    
    print(f"\nTotal tokens processed: {total_tokens:,}")
    print(f"Total windows created: {total_windows:,}")
    print(f"\nPer-shard breakdown:")
    for r in results:
        print(f"  Shard {r['shard_id']}: {r['total_tokens']:,} tokens, "
              f"{r['total_windows']:,} windows")
    
    # Merge if requested
    if MERGE_AFTER:
        merge_shards(OUTPUT_DIR, NUM_SHARDS, FINAL_OUTPUT)
        print(f"\n✓ Final dataset saved to: {FINAL_OUTPUT}")
    else:
        print(f"\n✓ Shards saved in: {OUTPUT_DIR}/")
        print("  Use these directly in your DataLoader for distributed training")
    
    print("\nDone!")

In [1]:
# split_text.py
import os
from tqdm import tqdm

INPUT_TXT = "combined.txt"
SHARD_DIR = "text_shards"
NUM_SHARDS = 100

os.makedirs(SHARD_DIR, exist_ok=True)

# Count lines (single pass)
print("Counting lines...")
with open(INPUT_TXT, "r", encoding="utf-8") as f:
    total_lines = sum(1 for _ in f)

lines_per_shard = total_lines // NUM_SHARDS
print(f"Total lines: {total_lines:,}")
print(f"Lines per shard: {lines_per_shard:,}")

writers = []
for i in range(NUM_SHARDS):
    f = open(f"{SHARD_DIR}/shard_{i}.txt", "w", encoding="utf-8")
    writers.append(f)

with open(INPUT_TXT, "r", encoding="utf-8") as f:
    for i, line in enumerate(tqdm(f, total=total_lines)):
        shard_id = min(i // lines_per_shard, NUM_SHARDS - 1)
        writers[shard_id].write(line)

for f in writers:
    f.close()

print("✓ Text sharding complete")


Counting lines...
Total lines: 26,557,763
Lines per shard: 265,577


100%|██████████| 26557763/26557763 [00:41<00:00, 636014.46it/s]

✓ Text sharding complete





In [2]:
# tokenize_shards.py
import sentencepiece as spm
import torch
import os
from tqdm import tqdm
from multiprocessing import Pool

TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
TEXT_SHARD_DIR = "text_shards"
OUT_DIR = "tokenized_chunks"

MAX_LEN = 256
STRIDE = 128
CHUNK_SIZE = 10_000   # windows per file
PARALLEL_SHARDS = 4   # VERY IMPORTANT

os.makedirs(OUT_DIR, exist_ok=True)

def tokenize_text_shard(shard_id):
    sp = spm.SentencePieceProcessor()
    sp.load(TOKENIZER_MODEL)

    PAD = sp.pad_id()
    EOS = sp.eos_id()

    buffer = []
    chunk = []
    chunk_id = 0

    in_path = f"{TEXT_SHARD_DIR}/shard_{shard_id}.txt"

    with open(in_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc=f"Shard {shard_id}"):
            tokens = sp.encode(line.strip(), out_type=int)
            tokens.append(EOS)
            buffer.extend(tokens)

            while len(buffer) >= MAX_LEN:
                window = buffer[:MAX_LEN]
                buffer = buffer[STRIDE:]

                chunk.append({
                    "input_ids": torch.tensor(window, dtype=torch.long),
                    "attention_mask": torch.ones(MAX_LEN, dtype=torch.long)
                })

                if len(chunk) == CHUNK_SIZE:
                    out = f"{OUT_DIR}/s{shard_id}_c{chunk_id}.pt"
                    torch.save(chunk, out)
                    chunk.clear()
                    chunk_id += 1

    # flush remainder
    if buffer:
        pad = MAX_LEN - len(buffer)
        chunk.append({
            "input_ids": torch.tensor(buffer + [PAD]*pad),
            "attention_mask": torch.tensor([1]*len(buffer) + [0]*pad)
        })

    if chunk:
        out = f"{OUT_DIR}/s{shard_id}_c{chunk_id}.pt"
        torch.save(chunk, out)

    return shard_id


if __name__ == "__main__":
    shard_ids = list(range(100))

    for i in range(0, 100, PARALLEL_SHARDS):
        batch = shard_ids[i:i+PARALLEL_SHARDS]
        with Pool(PARALLEL_SHARDS) as p:
            p.map(tokenize_text_shard, batch)


Shard 0: 265577it [00:06, 43355.68it/s]
Shard 1: 265577it [00:06, 43314.71it/s]
Shard 2: 265577it [00:06, 42174.16it/s]
Shard 3: 265577it [00:06, 40732.77it/s]
Shard 7: 265577it [00:04, 53952.66it/s]
Shard 6: 265577it [00:05, 52524.17it/s]
Shard 5: 265577it [00:05, 47956.92it/s]
Shard 4: 265577it [00:06, 43892.23it/s]
Shard 11: 265577it [00:04, 55903.26it/s]
Shard 10: 265577it [00:04, 55729.68it/s]
Shard 8: 265577it [00:04, 54781.25it/s]
Shard 9: 265577it [00:05, 50968.22it/s]
Shard 12: 265577it [00:05, 45521.94it/s]
Shard 15: 265577it [00:05, 44294.69it/s]
Shard 13: 265577it [00:06, 42965.95it/s]
Shard 14: 265577it [00:06, 42830.11it/s]
Shard 19: 265577it [00:05, 48735.42it/s]
Shard 17: 265577it [00:05, 47988.38it/s]
Shard 18: 265577it [00:05, 47353.98it/s]
Shard 16: 265577it [00:06, 44192.73it/s]
Shard 21: 265577it [00:06, 38828.42it/s]
Shard 20: 265577it [00:06, 38365.78it/s]
Shard 22: 265577it [00:07, 37742.32it/s]
Shard 23: 265577it [00:07, 36684.39it/s]
Shard 26: 265577it [00:07,

In [3]:
# merge_chunks.py
import torch
import glob
from tqdm import tqdm

OUT_FILE = "bert_dataset_256.pt"

all_chunks = sorted(glob.glob("tokenized_chunks/*.pt"))

merged = []
for path in tqdm(all_chunks):
    data = torch.load(path)
    merged.extend(data)

    if len(merged) >= 50_000:
        torch.save(merged, OUT_FILE)
        merged.clear()

if merged:
    torch.save(merged, OUT_FILE)

print("✓ Merge complete")


100%|██████████| 922/922 [14:00<00:00,  1.10it/s]

✓ Merge complete





In [4]:
import torch

data = torch.load("tokenized_chunks/s0_c0.pt")

print(len(data))
print(data[0].keys())
print(data[0]["input_ids"].shape)
print(data[0]["attention_mask"].sum())

10000
dict_keys(['input_ids', 'attention_mask'])
torch.Size([256])
tensor(256)


In [5]:
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("tokenizer/unigram_32000_0.9995.model")

sample = torch.load("tokenized_chunks/s0_c0.pt")[0]
tokens = sample["input_ids"].tolist()

# remove padding
tokens = [t for t in tokens if t != sp.pad_id()]

print(sp.decode(tokens))

විදේශිකයෙක් දළදා මාළිගාව ට ඉහළින් ඩ්‍රෝනයක් යවයි ඩ්‍රෝන යානයක් යොදා ගනිමින් මහනුවර <unk>--user_defined_symbols=[MASK]තිහාසික ශ්‍රී දළදා මාළිගාවට අයත් අධිආරක්ෂිත කලාපයේ වීඩියෝ දර්ශන ලබාගත් බංග්ලාදේශ ජාතිකයෙකු අත්අඩංගුවට ගෙන තිබේ. සැකකරු පොලීසියට පවසා ඇත්තේ තමන් කාලයක් තිස්සේ පවත්වාගෙන යන යූ ටියුබ් නාලිකාවක් වෙනුවෙන් අදාළ දර්ශන ලබාගත් බවය. ගමන් බෑගයකින් හමුවූ කාන්තා සිරුර - ඝාතනයට හේතුව හෙළිවෙයි සපුගස්කන්ද ප්‍රදේශයේ ගමන් බෑගයක තිබියදී සොයාගත් කාන්තා මළ සිරුර හමුවීමේ සිද්ධියට අදාළව අඹුසැමි යුවළක් අද අත්අඩංගුවට ගෙන තිබුණි. පොලීසිය පැවසුවේ, කාන්තාව ඝාතනය කර සිරුර ගමන් බෑගයක දමා එම ස්ථානයට ගෙනැවිත් ඇති බවට ඔවුන්ගෙන් සිදුකරන ලද ප්‍රශ්න කිරීම්වලදී අනාවරණ වූ බවය. සූදුවට ඇබ්බැහිව සිටි එම කාන්තාවගේ රන් භාණ්ඩ සහ මුදල් ලබාගැනීමේ අරමුණින් ඝාතනය සිදුකළ බව සැකකරුවන් අනාවරණ කර තිබේ. සපුගස්කන්ද මාබිම මාර්ගයේ කසළ ගොඩක ගමන් බෑගයක බහා තිබියදී සොයාගත් මළ සිරුර මොහොමඩ් සාෂි ෆාතුමා මුම්ටාස් නැමැති 42 හැවිරිදි කාන්තාවගේ බවට ඊයේ හඳුනාගැනුණි. පොලීසිය පැවසුවේ, කොළඹ 10, රම්සා පෙදෙස, මාලිගාවත්ත මහල් නිවාස සංකීර්ණයේ

In [6]:
a = torch.load("tokenized_chunks/s0_c0.pt")[0]["input_ids"]
b = torch.load("tokenized_chunks/s0_c0.pt")[1]["input_ids"]

print(torch.equal(a[128:], b[:128]))

True


In [7]:
from collections import Counter
import torch
import random

counter = Counter()

files = random.sample(list(sorted(
    __import__("glob").glob("tokenized_chunks/*.pt")
)), 5)

for f in files:
    data = torch.load(f)
    for item in data[:100]:
        counter.update(item["input_ids"].tolist())

print(counter.most_common(10))

[(5, 6354), (3, 4163), (6, 1787), (8, 888), (7, 787), (10, 727), (24, 658), (13, 652), (12, 629), (9, 629)]


In [8]:
for f in files:
    data = torch.load(f)
    for item in data[:50]:
        assert item["input_ids"].dtype == torch.long
        assert item["attention_mask"].dtype == torch.long
        assert item["input_ids"].shape[0] == 256

In [9]:
import torch, glob
total = 0
for f in glob.glob("tokenized_chunks/*.pt"):
    total += len(torch.load(f))
print(f"Total windows: {total:,}")


Total windows: 8,603,777
