In [79]:
%load_ext autoreload

import json
import numpy as np
from os.path import join

import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [63]:
# --- Configuration ---------------------------------------------------------
data_root = "/mnt/ssd/data/fineweb-edu-10BT/llama-tokenizer-debug"
mode = "train"  # train | validation | test (selects the "<mode>_shards" key in metadata.json)
batch_size = 4
sequence_length = 128

# --- Dataset metadata ------------------------------------------------------
metadata_path = join(data_root, 'metadata.json')
with open(metadata_path, 'r') as f:
    metadata = json.loads(f.read())

# --- Token shard -----------------------------------------------------------
# Take the second shard id listed for the selected split and load its token array.
shard_id = metadata[f"{mode}_shards"][1]
shard = np.load(join(data_root, f"tokens_shard_{shard_id:03d}.npy"))  # (n,)

if mode == "train":
    # Shuffle at document granularity: cut the token stream at every BOS token
    # except the first one, permute the pieces in place, then stitch them back.
    bos_token = 1
    bos_indices = np.nonzero(shard == bos_token)[0]
    shard_splits = np.split(shard, bos_indices[1:])
    np.random.shuffle(shard_splits)
    shard = np.concatenate(shard_splits)  # (n,)

In [64]:
# Every sequence needs one extra trailing token to act as the target of its last
# position, so a shard whose length is an exact multiple of `sequence_length`
# gives up one sequence.
_num_sequences = len(shard) // sequence_length - int(len(shard) % sequence_length == 0)
# The final batch may be only partially filled, hence the ceiling.
_num_batches = np.ceil(_num_sequences / batch_size).astype(int)

In [88]:
# Display the dataset metadata parsed from metadata.json.
metadata

{'dataset_name': 'fineweb-edu-10BT',
 'tokenizer_name': 'hf-internal-testing/llama-tokenizer',
 'max_tokens_per_shard': 10000,
 'num_training_tokens': 36963,
 'train_shards': [1, 0, 3, 2],
 'validation_shards': [],
 'test_shards': [4],
 'num_shards': 5,
 'num_tokens': [9928, 9082, 9630, 8323, 9655],
 'num_documents': [8, 14, 12, 7, 10]}

In [128]:
class SingleShardDataLoader:
    """Serve next-token-prediction batches from a single token shard.

    Use this only for debugging.

    The first shard id listed under ``metadata["<mode>_shards"]`` is loaded with
    ``np.load``. Each call to :meth:`next_batch` returns a pair ``(xs, ys)`` of
    ``int64`` tensors reshaped to ``(-1, sequence_length)``, where ``ys`` holds
    the tokens that follow ``xs`` by one position in the shard. The last batch
    of an epoch may contain fewer than ``batch_size`` rows.

    In ``"train"`` mode the shard is shuffled document-wise (see
    :meth:`_shuffle`) on construction, on :meth:`reset` and whenever an epoch
    ends. In every mode the batch cursor wraps around at the end of an epoch
    and ``_current_epoch`` is incremented.
    """

    def __init__(self, data_root, mode, batch_size, sequence_length):
        """Load the shard and pre-compute the batching layout.

        Args:
            data_root: Directory containing ``metadata.json`` and the
                ``tokens_shard_XXX.npy`` files.
            mode: Split name; ``f"{mode}_shards"`` must be a key of the
                metadata (train|validation|test).
            batch_size: Number of sequences per batch (>= 1).
            sequence_length: Number of tokens per sequence (>= 1).

        Raises:
            ValueError: If ``batch_size`` or ``sequence_length`` is below 1.
            KeyError: If the metadata has no ``f"{mode}_shards"`` entry.
            IndexError: If the shard list for ``mode`` is empty.
        """
        if batch_size < 1 or sequence_length < 1:
            raise ValueError(
                f"batch_size and sequence_length must be >= 1, "
                f"got batch_size={batch_size}, sequence_length={sequence_length}"
            )

        self.mode = mode  # train|validation|test
        self.batch_size = batch_size
        self.sequence_length = sequence_length

        # Load dataset metadata
        with open(join(data_root, 'metadata.json'), 'r') as f:
            metadata = json.load(f)

        shards_key = f"{mode}_shards"
        if shards_key not in metadata:
            raise KeyError(f"metadata.json has no '{shards_key}' entry (mode={mode!r})")
        shard_ids = metadata[shards_key]
        if not shard_ids:
            raise IndexError(f"metadata.json lists no shards under '{shards_key}'")
        self.shard = np.load(join(data_root, f"tokens_shard_{shard_ids[0]:03d}.npy"))  # (n,)

        if mode == "train":
            self.shard = self._shuffle(self.shard)

        # A sequence is usable only if one more token follows it (the target of
        # its last position), hence the "- 1". Clamp at 0 so an empty shard does
        # not produce negative counts.
        self._num_sequences = max(0, (len(self.shard) - 1) // sequence_length)
        # Integer ceiling division keeps this a plain Python int for __len__.
        self._num_batches = (self._num_sequences + batch_size - 1) // batch_size
        # Tokens beyond this index are never read (inputs plus one trailing target).
        self._max_tokens = self._num_sequences * sequence_length + 1

        self._current_batch = 0
        self._current_epoch = 0

    def next_batch(self):
        """Return the next ``(xs, ys)`` batch and advance the cursor.

        Returns:
            Tuple ``(xs, ys)`` of ``int64`` tensors, each reshaped to
            ``(-1, sequence_length)``; ``ys`` is ``xs`` shifted by one token.

        After the last batch of an epoch the cursor resets to 0 and the epoch
        counter is incremented in every mode; the shard is reshuffled only in
        ``"train"`` mode.
        """
        tokens_per_batch = self.batch_size * self.sequence_length
        start_index = self._current_batch * tokens_per_batch
        # "+ 1" pulls in the target of the batch's final input token; the clamp
        # trims the (possibly partial) last batch to whole sequences.
        end_index = min(start_index + tokens_per_batch + 1, self._max_tokens)
        batch = self.shard[start_index: end_index]  # (bs * sl + 1,)
        batch = torch.from_numpy(batch.astype(np.int64))
        xs = batch[:-1].reshape(-1, self.sequence_length)  # (bs, sl)
        ys = batch[1:].reshape(-1, self.sequence_length)  # (bs, sl)

        self._current_batch += 1
        if self._current_batch >= self._num_batches:
            # Wrap around in every mode so later calls never yield empty batches.
            self._current_batch = 0
            self._current_epoch += 1
            if self.mode == "train":
                self.shard = self._shuffle(self.shard)

        return xs, ys

    def __len__(self):
        """Return the number of batches in one epoch."""
        return self._num_batches

    def reset(self):
        """Rewind to the first batch of epoch 0, reshuffling in ``"train"`` mode."""
        self._current_batch = 0
        self._current_epoch = 0
        if self.mode == "train":
            self.shard = self._shuffle(self.shard)

    @staticmethod
    def _shuffle(shard, bos_token=1):
        """Return a copy of ``shard`` with its documents in random order.

        The array is cut at every occurrence of ``bos_token`` except the first,
        so any tokens before the first BOS stay attached to the first piece.
        The pieces are permuted with ``np.random.shuffle`` (global NumPy RNG)
        and concatenated. A shard with at most one BOS token comes back in its
        original order.

        Args:
            shard: 1-D token array.
            bos_token: Token id that marks the start of a document.
        """
        bos_indices = np.nonzero(shard == bos_token)[0]
        shard_splits = np.split(shard, bos_indices[1:])
        np.random.shuffle(shard_splits)
        return np.concatenate(shard_splits)  # (n,)

In [129]:
# Build the debug loader on the training split.
dataloader = SingleShardDataLoader(
    data_root="/mnt/ssd/data/fineweb-edu-10BT/llama-tokenizer-debug",
    mode="train",  # train|validation|test -- must match a "<mode>_shards" key in metadata.json
    batch_size=4,
    sequence_length=128
)

In [130]:
# Step through one epoch, printing the loader's batch cursor after every call.
# In train mode the cursor wraps to 0 and the epoch counter increments on the
# last batch, which is why the final printed value is 0 and the loop exits.
dataloader.reset()
for i in range(len(dataloader)):
    x, y = dataloader.next_batch()
    print(dataloader._current_batch)
    if dataloader._current_epoch > 0:
        break

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
0


In [127]:
# Upper bound on the slice end used by next_batch: _num_sequences * sequence_length + 1.
dataloader._max_tokens

8961

In [113]:
# NOTE(review): the operand inside the parentheses is missing; as written this
# raises TypeError. The recorded output 2.9453125 equals 377 / 128 -- restore
# the original expression or delete this cell.
() / 128

2.9453125

In [114]:
# Total number of tokens in the loaded shard.
len(dataloader.shard)

9082

In [116]:
# (number of full 128-token sequences, leftover tokens) in the shard.
len(dataloader.shard) // 128, len(dataloader.shard) % 128

(70, 122)

In [117]:
# Number of usable sequences per epoch as computed by the loader.
dataloader._num_sequences

70

In [118]:
# Number of batches per epoch: ceil(_num_sequences / batch_size).
dataloader._num_batches

18

In [119]:
# A fractional result here means the last batch of the epoch is only partially filled.
dataloader._num_sequences / 4

17.5