## Scaling Step Counts

For chapter 4, we not only want to find a more effective way to build a MiniPile, but we also want to downsize MiniPile further. The former yields datasets with roughly the same example count and dimensionality as MiniPile; the latter, by design, does not.

To downsize MiniPile and still evaluate the downsized datasets correctly, we need to adjust the number of training steps. We use the same calculation as in chapter 2, which determined the optimal step count for training on MiniPile.
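Written out, and assuming (as the code below does) that the number of tokens is proportional to the uncompressed byte size, the scale factor $s$ and the scaled step count $T_{\text{MiniPile}}$ are given below. Here $B$ denotes uncompressed bytes, $e = 1.5$ the effective number of epochs, and $T_{\text{Pile}} = 143000$ the step count of the original run. The numbers are the ones printed for `MiniPile_DensityNano`:

$$
s = \frac{B_{\text{Pile}} \cdot e}{B_{\text{MiniPile}} \cdot e} = \frac{B_{\text{Pile}}}{B_{\text{MiniPile}}} = \frac{824546807506}{4774049161} \approx 172.714,
\qquad
T_{\text{MiniPile}} = \frac{T_{\text{Pile}}}{s} = \frac{143000}{172.714} \approx 827.957 \approx 828
$$

Because $e$ appears in both numerator and denominator, it cancels; only the byte ratio determines the step count.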

In [None]:
import pyarrow.parquet as pq
from pathlib import Path
import os

In [None]:
base_dir = '/vol/tmp/koppelmm'

In [9]:
def get_uncomp_dataset_size(dataset_path):
    size_bytes = 0
    for root, _, files in os.walk(dataset_path):
        for file in files:
            if file.endswith('.parquet'):
                # Only count Parquet data files; skip any other files in the dataset directory
                pq_table = pq.read_table(os.path.join(root, file))
                size_bytes += sum(col.nbytes for col in pq_table.columns)
    return size_bytes

dataset_path = Path(base_dir) / 'MiniPile_DensityNano'
uncompressed_size = get_uncomp_dataset_size(dataset_path)
print(f"Uncompressed byte size of the Parquet dataset: {uncompressed_size} bytes")

# I use the byte sizes as a proxy for the number of tokens, as both datasets are tokenized with the same tokenizer
minipile_train_bytes = uncompressed_size
pile_train_bytes = 824546807506   # see https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/blob/main/dataset_infos.json
pile_effective_epochs = 1.5       # epochs the original model was effectively trained for; cancels out of the scale factor, but the training parameters depend on it

scale_factor = (pile_train_bytes * pile_effective_epochs) / (minipile_train_bytes * pile_effective_epochs)
print(f"Byte-based scale factor: {scale_factor:10.6f}x")
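pile_train_iters = 143000         # train iterations (= LR decay iterations) of the original model
print(f"MiniPile (scaled) Train-Iters/LR-Decay-Iters: {pile_train_iters / scale_factor:.3f} ~ {round(pile_train_iters / scale_factor)}")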

Uncompressed byte size of the Parquet dataset: 4774049161 bytes
Byte-based scale factor: 172.714352x
MiniPile (scaled) Train-Iters/LR-Decay-Iters: 827.957 ~ 828
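Since chapter 4 needs this step count for several downsized variants, the calculation can be wrapped in a small helper. The sketch below is hypothetical: the name `scaled_train_iters` and the example fractions are illustrative and not part of the actual pipeline. It keeps the byte-size-as-token-proxy assumption and reuses the reference values from the cell above.

```python
# Hypothetical helper: scale a reference step count to a smaller dataset,
# using uncompressed byte size as a proxy for token count (same tokenizer assumed).
PILE_TRAIN_BYTES = 824_546_807_506  # deduplicated Pile train split (dataset_infos.json)
REFERENCE_TRAIN_ITERS = 143_000     # train iterations of the original model


def scaled_train_iters(dataset_bytes: int,
                       reference_bytes: int = PILE_TRAIN_BYTES,
                       reference_iters: int = REFERENCE_TRAIN_ITERS) -> int:
    """Return the step count that keeps the number of passes over the data
    equal to the reference run. The epoch count cancels out of the ratio,
    so it is not a parameter here."""
    if dataset_bytes <= 0:
        raise ValueError("dataset_bytes must be positive")
    scale_factor = reference_bytes / dataset_bytes
    return round(reference_iters / scale_factor)


minipile_bytes = 4_774_049_161  # measured above for MiniPile_DensityNano
print(scaled_train_iters(minipile_bytes))  # 828

# Hypothetical downsized variants, expressed as fractions of MiniPile's byte size
variant_iters = {
    fraction: scaled_train_iters(int(minipile_bytes * fraction))
    for fraction in (0.75, 0.5, 0.25)
}
print(variant_iters)
```

Rounding to the nearest integer mirrors the `round(...)` in the original cell; note that Python's `round` uses round-half-to-even, which only matters if the scaled value lands exactly on `.5`.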
