In [None]:
import os
from tqdm import tqdm

INPUT_TXT = "combined.txt"
SHARD_DIR = "text_shards"
NUM_SHARDS = 100

os.makedirs(SHARD_DIR, exist_ok=True)

# Count lines up front (one extra pass over the input) so the file can be
# cut into contiguous, roughly equal shards and tqdm can show a real total.
print("Counting lines...")
with open(INPUT_TXT, "r", encoding="utf-8") as f:
    total_lines = sum(1 for _ in f)

# Clamp to at least 1: with fewer input lines than NUM_SHARDS the floor
# division yields 0, and `i // lines_per_shard` below would raise
# ZeroDivisionError. With the clamp, each line simply gets its own shard.
lines_per_shard = max(1, total_lines // NUM_SHARDS)
print(f"Total lines: {total_lines:,}")
print(f"Lines per shard: {lines_per_shard:,}")

# Every shard file is created (possibly empty), so all NUM_SHARDS paths
# exist afterwards. The try/finally guarantees the handles are flushed and
# closed even if reading or writing fails part-way through.
writers = []
try:
    for i in range(NUM_SHARDS):
        writers.append(open(f"{SHARD_DIR}/shard_{i}.txt", "w", encoding="utf-8"))

    with open(INPUT_TXT, "r", encoding="utf-8") as f:
        for i, line in enumerate(tqdm(f, total=total_lines)):
            # Lines beyond NUM_SHARDS * lines_per_shard (the division
            # remainder) are folded into the last shard.
            shard_id = min(i // lines_per_shard, NUM_SHARDS - 1)
            writers[shard_id].write(line)
finally:
    for writer in writers:
        writer.close()

print("✓ Text sharding complete")

In [None]:
import sentencepiece as spm
import torch
import os
from tqdm import tqdm
from multiprocessing import Pool

# SentencePiece model file loaded by every tokenization worker.
TOKENIZER_MODEL = "tokenizer/unigram_32000_0.9995.model"
# Directory holding the input text shards, read as shard_<id>.txt.
TEXT_SHARD_DIR = "text_shards"
# Directory that receives the serialized window chunks (s<shard>_c<chunk>.pt).
OUT_DIR = "tokenized_chunks"

# Number of token ids in each emitted window.
MAX_LEN = 256
# Tokens dropped from the front of the buffer after each window, so
# consecutive windows overlap by MAX_LEN - STRIDE tokens.
STRIDE = 128
CHUNK_SIZE = 10_000   # windows per file
# Number of shards tokenized concurrently (size of the worker pool).
# NOTE(review): flagged as very important by the original author --
# presumably because each worker holds up to CHUNK_SIZE windows in memory
# before saving; confirm before raising it.
PARALLEL_SHARDS = 4

# Make sure the output directory exists before any worker tries to save.
os.makedirs(OUT_DIR, exist_ok=True)

def tokenize_text_shard(shard_id):
    """Tokenize one text shard into overlapping fixed-length windows.

    Reads ``{TEXT_SHARD_DIR}/shard_{shard_id}.txt`` line by line, encodes
    each stripped line with the SentencePiece model at ``TOKENIZER_MODEL``,
    appends the EOS id after every line, and slides a window of ``MAX_LEN``
    tokens over the resulting stream, advancing ``STRIDE`` tokens per step.
    Windows are collected as dicts with ``"input_ids"`` and
    ``"attention_mask"`` tensors and written with ``torch.save`` to
    ``{OUT_DIR}/s{shard_id}_c{n}.pt`` in groups of ``CHUNK_SIZE``.

    Any tokens left at the end of the shard are emitted as one final window,
    right-padded with the pad id and masked with zeros over the padding.

    Args:
        shard_id: Index of the shard file to process.

    Returns:
        The same ``shard_id``, so callers can tell which shard finished.
    """
    processor = spm.SentencePieceProcessor()
    processor.load(TOKENIZER_MODEL)

    # NOTE(review): SentencePiece reports -1 for special ids the model does
    # not define -- confirm this model has a real pad id before training on
    # the padded final window.
    pad_token = processor.pad_id()
    eos_token = processor.eos_id()

    source_path = f"{TEXT_SHARD_DIR}/shard_{shard_id}.txt"

    pending = []       # token ids not yet slid past by the window
    windows = []       # finished windows waiting to be written out
    next_chunk = 0     # index used in the next output file name

    def write_out():
        """Save the collected windows to the next chunk file and reset."""
        nonlocal next_chunk
        torch.save(windows, f"{OUT_DIR}/s{shard_id}_c{next_chunk}.pt")
        windows.clear()
        next_chunk += 1

    with open(source_path, "r", encoding="utf-8") as handle:
        for raw_line in tqdm(handle, desc=f"Shard {shard_id}"):
            pending.extend(processor.encode(raw_line.strip(), out_type=int))
            pending.append(eos_token)

            # Emit every full window currently available in the buffer.
            while len(pending) >= MAX_LEN:
                windows.append({
                    "input_ids": torch.tensor(pending[:MAX_LEN], dtype=torch.long),
                    "attention_mask": torch.ones(MAX_LEN, dtype=torch.long),
                })
                # Advance by STRIDE only, keeping the overlap for the next window.
                del pending[:STRIDE]

                if len(windows) == CHUNK_SIZE:
                    write_out()

    # Whatever is left is shorter than MAX_LEN: pad it out to a full window.
    leftover = len(pending)
    if leftover:
        padding = MAX_LEN - leftover
        windows.append({
            "input_ids": torch.tensor(pending + [pad_token] * padding),
            "attention_mask": torch.tensor([1] * leftover + [0] * padding),
        })

    if windows:
        write_out()

    return shard_id


if __name__ == "__main__":
    # Must match the number of shard_<i>.txt files in TEXT_SHARD_DIR.
    num_shards = 100
    shard_ids = list(range(num_shards))

    # Use one pool for the whole run instead of building a new pool for every
    # batch of PARALLEL_SHARDS shards: a free worker slot is refilled right
    # away rather than idling until the slowest shard of its batch finishes.
    # At most PARALLEL_SHARDS shards are still processed at the same time.
    #   - maxtasksperchild=1 replaces each worker process after one shard,
    #     preserving the process recycling the per-batch pools provided.
    #   - chunksize=1 hands shards to workers one at a time.
    with Pool(PARALLEL_SHARDS, maxtasksperchild=1) as p:
        p.map(tokenize_text_shard, shard_ids, chunksize=1)
