In [16]:

import math
import os
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
# Fetch the TinyStories corpus (roneneldan/TinyStories) from the Hugging Face
# Hub as a DatasetDict and show its splits and row counts.
dataset = load_dataset(path="roneneldan/TinyStories")
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [None]:
 # Save the raw (untokenized) DatasetDict to disk so that later sessions can
# reload it with load_from_disk instead of downloading it again. Nothing is
# preprocessed here; tokenization happens in a later cell.
dataset.save_to_disk("../tinystories_dataset")

Saving the dataset (4/4 shards): 100%|██████████| 2119719/2119719 [00:01<00:00, 1697668.16 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 21990/21990 [00:00<00:00, 1390701.96 examples/s]


In [None]:
# Later sessions can start here: reload the DatasetDict written by
# save_to_disk rather than fetching it from the Hub again.
ds = load_from_disk(dataset_path="../tinystories_dataset")
display(ds)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [21]:
# Peek at a single raw training story (a dict with one 'text' field) to see
# what the tokenizer will receive.
train_split = ds['train']
example = train_split[0]
display(example)

{'text': 'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'}

In [22]:
def process(example, encoder=None):
    """Tokenize a single example into integer IDs and report its length.

    Used as the mapping function for ``ds.map`` in the preprocessing cell.

    Args:
        example (Mapping[str, Any]): A dataset row containing at least the key
            'text' mapped to a string.
        encoder (optional): Object providing ``encode_ordinary(text)`` that
            returns a sequence of token IDs. When omitted (the default, and how
            ``ds.map(process, ...)`` calls it), the notebook-global ``enc`` is
            looked up at call time. Passing an encoder explicitly removes the
            dependency on that hidden global.

    Returns:
        dict: {'ids': <token IDs returned by encode_ordinary>,
               'len': <number of token IDs (int)>}

    Raises:
        NameError: If ``encoder`` is None and no global ``enc`` is defined.

    NOTE(review): no end-of-text/separator token is appended, so once the
    'ids' of consecutive examples are concatenated into one stream, story
    boundaries are not marked. If that is not intended, append the encoder's
    EOT token here (e.g. ``enc.eot_token`` for a tiktoken Encoding) -- TODO
    confirm against the training code.
    """
    # `enc` is presumably a tiktoken Encoding (tiktoken is imported, but the
    # encoder is created outside the visible cells) -- TODO confirm.
    tokenizer = enc if encoder is None else encoder

    # Per tiktoken's docs, encode_ordinary encodes text without special-token
    # handling; this function itself inserts no BOS/EOS tokens either.
    ids = tokenizer.encode_ordinary(example['text'])

    # Package both the token IDs and their length for downstream batching/writing.
    return {'ids': ids, 'len': len(ids)}

In [23]:
# -----------------------------------------------------------------------------
# Dataset preprocessing: tokenize each split of `ds` and write its token IDs,
# concatenated in dataset order, to a flat binary file "<split>.bin"
# (e.g. train.bin, validation.bin) under ../preprocess_data.
#
# This cell:
#   1) Finds the splits of `ds` whose "<split>.bin" does not exist yet; if all
#      of them exist the cell does nothing. Checking every split (not only
#      train.bin) means a missing validation.bin is regenerated too.
#   2) Creates the output directory if needed (np.memmap cannot create it).
#   3) Maps the tokenization function `process` over each missing split
#      (dropping the raw 'text' column, num_proc=8 worker processes).
#   4) Preallocates a memmap sized to the split's total token count (sum of
#      the 'len' column) and copies tokens into it shard by shard, in order.
#   5) Writes to "<split>.bin.tmp" and renames it to "<split>.bin" only after
#      every token was written, so an interrupted run cannot leave a
#      truncated .bin file that step 1 would later treat as finished.
#
# Requirements (defined in earlier cells):
#   - ds: a datasets.DatasetDict whose splits have a 'text' column.
#   - process: callable returning {'ids': <token IDs>, 'len': <count>} for one
#              example.
#   - os, np (NumPy) and tqdm must be imported.
#
# Output:
#   - For each split: "<split>.bin" containing the concatenated token IDs as a
#     raw uint64 array written through np.memmap (no header).
#
# Notes:
#   - dtype is np.uint64; downstream readers (e.g. get_batch) must open the
#     files with the same dtype.
#   - This cell inserts no separator token between examples.
#   - with_format('numpy') makes each shard's 'ids' come back as NumPy arrays,
#     which np.concatenate joins into one flat array.
#   - Contiguous shards visited in index order preserve the dataset's order.
# -----------------------------------------------------------------------------

out_dir = "../preprocess_data"

# Binary dtype for stored token IDs; keep consistent across the pipeline.
dtype = np.uint64

# Upper bound on the number of contiguous shards used to copy a split.
total_batches = 1024

# Only splits whose output file is missing need work.
missing_splits = [split for split in ds if not os.path.exists(f'{out_dir}/{split}.bin')]

if missing_splits:
    # np.memmap(mode='w+') fails if the parent directory does not exist.
    os.makedirs(out_dir, exist_ok=True)

for split in missing_splits:
    dset = ds[split].map(
        process,
        remove_columns=['text'],
        desc=f"tokenizing {split}",
        num_proc=8,
    )

    # Total number of tokens in the split, as a plain int so the memmap
    # shape and byte-size arithmetic stay in integer types.
    arr_len = int(np.sum(dset['len'], dtype=np.uint64))

    filename = f'{out_dir}/{split}.bin'
    tmp_filename = f'{filename}.tmp'

    if arr_len == 0:
        # np.memmap cannot map an empty file, so write the empty .bin directly.
        open(filename, 'wb').close()
        continue

    # Preallocate a writable memmap of the exact total length.
    arr = np.memmap(tmp_filename, dtype=dtype, mode='w+', shape=(arr_len,))

    # Never request more shards than rows: surplus contiguous shards would be
    # empty, and np.concatenate raises on an empty list.
    num_shards = min(total_batches, len(dset))

    # Current write position within the memmap.
    idx = 0
    for batch_idx in tqdm(range(num_shards), desc=f'writing {filename}'):
        # Take a contiguous shard of the dataset and return NumPy arrays.
        batch = dset.shard(num_shards=num_shards, index=batch_idx, contiguous=True).with_format('numpy')

        # Concatenate the shard's token ID arrays and append them to the memmap.
        arr_batch = np.concatenate(batch['ids'])
        arr[idx: idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)

    # Ensure all buffered writes reach the file, then verify nothing was lost.
    arr.flush()
    assert idx == arr_len, f'wrote {idx} tokens to {filename}, expected {arr_len}'

    # Release the mapping before renaming (on Windows a file that is still
    # mapped cannot be replaced), then publish the finished file.
    del arr
    os.replace(tmp_filename, filename)

tokenizing the splits (num_proc=8): 100%|██████████| 2119719/2119719 [00:36<00:00, 57780.92 examples/s]
tokenizing the splits (num_proc=8): 100%|██████████| 21990/21990 [00:00<00:00, 34057.11 examples/s]
writing ../preprocess_data/train.bin: 100%|██████████| 1024/1024 [04:36<00:00,  3.71it/s]
writing ../preprocess_data/validation.bin: 100%|██████████| 1024/1024 [00:03<00:00, 260.07it/s]
