# In [None]:
import ujson as json
import random
from datasets import load_dataset

# Records buffered in memory before a chunk is flushed to disk.
CHUNK_SIZE: int = 50
# Stop streaming once this many samples have been kept.
MAX_SAMPLES: int = 500
# Samples whose stripped text is this many characters or longer are dropped.
MAX_CHARS: int = 512
# JSON Lines file that the formatted samples are written to.
OUTPUT_PATH: str = "pretrain_smollm.jsonl"

def write_chunk(chunk, file_path, rng=None):
    """Shuffle a copy of ``chunk`` and append it to ``file_path`` as JSON Lines.

    Args:
        chunk: Iterable of JSON-serialisable records (dicts in this script).
            It is not modified; a shuffled copy is written instead.
        file_path: Destination path. The file is opened in append mode, so
            existing content is kept and the new records follow it.
        rng: Optional ``random.Random`` instance for reproducible shuffling.
            Defaults to the module-level ``random`` generator.
    """
    # Shuffle a copy so the caller's list is never mutated as a side effect.
    shuffled = list(chunk)
    (rng if rng is not None else random).shuffle(shuffled)

    # ensure_ascii=False keeps non-ASCII text readable instead of \u-escaped.
    lines = [json.dumps(item, ensure_ascii=False) + "\n" for item in shuffled]
    with open(file_path, "a", encoding="utf-8") as f:
        # One batched write instead of a write() call per record.
        f.writelines(lines)


def main():
    """Stream cosmopedia-v2, keep short samples, and write them to OUTPUT_PATH.

    Samples whose stripped ``text`` is empty, or is ``MAX_CHARS`` characters
    or longer, are skipped. Up to ``MAX_SAMPLES`` kept samples are wrapped as
    ``<|im_start|>...<|im_end|>`` and written in chunks of ``CHUNK_SIZE`` via
    ``write_chunk`` (which shuffles each chunk). ``OUTPUT_PATH`` is truncated
    before writing, so re-running does not append duplicates of earlier runs.
    A summary with an approximate (whitespace word) token count is printed.
    """
    print("🔄 Streaming smollm / cosmopedia-v2 dataset...")

    dataset = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "cosmopedia-v2",
        split="train",
        streaming=True
    )

    # write_chunk appends, so start from an empty file; otherwise re-running
    # the script (or notebook cell) accumulates duplicate samples.
    with open(OUTPUT_PATH, "w", encoding="utf-8"):
        pass

    buffer = []
    count = 0
    chunk_id = 1
    total_tokens = 0

    for sample in dataset:
        if count >= MAX_SAMPLES:
            break

        text = (sample.get("text") or "").strip()
        # Skip empty texts and enforce a strict limit: only texts shorter
        # than MAX_CHARS characters are kept.
        if not text or len(text) >= MAX_CHARS:
            continue

        # Rough token estimate: whitespace-delimited words. split() also
        # separates on newlines/tabs, which counting " " alone would miss.
        total_tokens += len(text.split())

        buffer.append({"text": f"<|im_start|>{text}<|im_end|>"})
        count += 1

        if len(buffer) >= CHUNK_SIZE:
            # write_chunk shuffles the chunk itself; no extra shuffle here.
            print(f"🧹 Writing chunk {chunk_id} ({len(buffer)} samples)... Total so far: {count}")
            write_chunk(buffer, OUTPUT_PATH)
            buffer = []
            chunk_id += 1

    if buffer:
        print(f"🧹 Writing final chunk ({len(buffer)} samples)...")
        write_chunk(buffer, OUTPUT_PATH)

    # ---------------- SUMMARY ----------------
    # Guard against division by zero when no sample passed the filters.
    avg_tokens = total_tokens / count if count else 0.0

    print("\n🎉 DONE!")
    print(f"📄 Saved to: {OUTPUT_PATH}")
    print(f"📦 Total samples written: {count}")
    print(f"🔢 Total tokens (approx): {total_tokens:,}")
    print(f"🔢 Total tokens in millions: {total_tokens / 1e6:.3f}M")
    print(f"🔢 Total tokens in billions: {total_tokens / 1e9:.4f}B")
    print(f"🧮 Avg tokens per sample: {avg_tokens:.2f}")


# Entry point: run main() only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()