In [3]:
# Run the dataset preparation script with the argument 8 (presumably the
# number of training shards to fetch -- confirm against data/cached_fineweb10B.py).
!python data/cached_fineweb10B.py 8

In [4]:
import glob, array
from pathlib import Path
import torch                          # only used for fast host-side dtype conversion

# ------------------------------------------------------------------
# 1.  Size of the token budget
# ------------------------------------------------------------------
# Run shape: the numbers below fully determine how many tokens get copied.
NUM_ITERATIONS = 1_770
TRAIN_SEQ_LEN = 48 * 1024        # 49,152
NUM_GPUS = 8
GRADIENT_ACCUM_STEPS = 1         # i.e. gradient accumulation is off

# Tokens consumed by one step across all GPUs, then by the whole run.
# NOTE(review): the accumulation factor divides here; with the value 1 it is
# a no-op -- confirm the intended direction before changing it.
TOKENS_PER_STEP = (NUM_GPUS * TRAIN_SEQ_LEN) // GRADIENT_ACCUM_STEPS
TOKENS_NEEDED = NUM_ITERATIONS * TOKENS_PER_STEP   # 695,992,320
print(f"{TOKENS_NEEDED:,} tokens will be copied into the new file")

# ------------------------------------------------------------------
# 2.  Helper to load a shard (same logic as _load_data_shard in train_gpt.py)
# ------------------------------------------------------------------
MAGIC = 20240520
VERSION = 1
HEADER_LEN = 256                    # int32 words

def load_shard(file: Path) -> torch.Tensor:
    header = torch.from_file(str(file), False, HEADER_LEN, dtype=torch.int32)
    assert header[0] == MAGIC and header[1] == VERSION, f"Bad header in {file}"
    ntok = int(header[2])
    with file.open("rb", buffering=0) as f:
        f.seek(HEADER_LEN * 4)
        buf = torch.empty(ntok, dtype=torch.uint16, pin_memory=False)  # host RAM is fine
        nread = f.readinto(buf.numpy())
        assert nread == 2 * ntok, "size mismatch"
    return buf

# ------------------------------------------------------------------
# 3.  Fill one buffer with the first TOKENS_NEEDED training tokens
# ------------------------------------------------------------------
# Shards are consumed in lexicographic filename order.
train_files = glob.glob("data/fineweb10B/fineweb_train_*.bin")
train_files.sort()
assert train_files, "no shards found"

tokens = torch.empty(TOKENS_NEEDED, dtype=torch.uint16)
cursor = 0                          # number of tokens copied so far
for file in train_files:
    shard = load_shard(Path(file))
    # Use the whole shard, or only the part still missing from the target.
    take = min(shard.numel(), TOKENS_NEEDED - cursor)
    stop = cursor + take
    tokens[cursor:stop].copy_(shard[:take])
    cursor = stop
    if cursor == TOKENS_NEEDED:
        break
assert cursor == TOKENS_NEEDED, "ran out of data before hitting the target"

# ------------------------------------------------------------------
# 4.  Write the new single-shard .bin file
# ------------------------------------------------------------------
# Output layout mirrors what load_shard reads: HEADER_LEN int32 header words
# (magic, version, token count, rest zero) followed by the uint16 tokens.
out_path = Path("data/train_first_696M_tokens.bin")
out_path.parent.mkdir(parents=True, exist_ok=True)

header = array.array("i", [0]*HEADER_LEN)
# The "i" typecode is the platform's C int; the format needs 4-byte words.
assert header.itemsize == 4, "header words must be 4 bytes wide"
header[0] = MAGIC
header[1] = VERSION
header[2] = TOKENS_NEEDED           # claim the real token count

# The header must describe exactly the payload that follows it.
assert tokens.numel() == TOKENS_NEEDED, "token buffer does not match header count"

with out_path.open("wb") as f:
    f.write(header.tobytes())
    # Hand the array to write() through the buffer protocol instead of
    # tobytes(), which would build a second full-size copy in memory first.
    f.write(tokens.numpy())
print(f"Wrote {out_path} ({out_path.stat().st_size/1e6:.1f} MB)")

695,992,320 tokens will be copied into the new file
Wrote data/train_first_696M_tokens.bin (1392.0 MB)
