Skip to content

Index error in replay chunks on world size > 1 #251

@awaelchli

Description

@awaelchli

🐛 Bug

A checkpoint saved from a StreamingDataLoader on multiple ranks can't be resumed at the moment.

from lightning import Fabric
import torch
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader


def get_dataloader():
    """Build the shuffled streaming dataloader over the optimized OpenWebText train data."""
    dataset = StreamingDataset(
        input_dir="data/openwebtext/optimized/train",
        item_loader=TokensLoader(block_size=1025),
        shuffle=True,
    )
    # Loader settings are kept together so the repro configuration is easy to see at a glance.
    loader_kwargs = dict(batch_size=12, pin_memory=True, num_workers=8, drop_last=True)
    return StreamingDataLoader(dataset, **loader_kwargs)


def dataloader_state():
    """Return the ``train_dataloader`` entry of ``checkpoint.pt``, for reproduction."""
    # weights_only=False lets torch.load unpickle arbitrary objects,
    # so only point this at a checkpoint you trust.
    return torch.load("checkpoint.pt", weights_only=False)["train_dataloader"]


if __name__ == "__main__":
    # Launch with 8 CPU devices so that ranks other than 0 take part in the run.
    fabric = Fabric(accelerator="cpu", devices=8)
    fabric.launch()

    # Rebuild the loader, then restore the state taken from the checkpoint.
    loader = get_dataloader()
    loader.load_state_dict(dataloader_state())

    # Requesting the first batch is enough to hit the failure on resume.
    batches = iter(loader)
    try:
        first_batch = next(batches)
    except Exception:
        print("Exception on rank", fabric.global_rank)
    else:
        print("No exception on rank", fabric.global_rank)

Output:

Exception on rank 4
Exception on rank 2
Exception on rank 5
Exception on rank 6
Exception on rank 1
Exception on rank 3
Exception on rank 7
No exception on rank 0

The exception on ranks > 0 is:

[rank6]: Traceback (most recent call last):
[rank6]:   File "/teamspace/studios/this_studio/repro.py", line 37, in <module>
[rank6]:     raise e
[rank6]:   File "/teamspace/studios/this_studio/repro.py", line 34, in <module>
[rank6]:     batch = next(iterator)
[rank6]:   File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataloader.py", line 623, in __iter__
[rank6]:     for batch in super().__iter__():
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
[rank6]:     data = self._next_data()
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
[rank6]:     return self._process_data(data)
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
[rank6]:     data.reraise()
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/_utils.py", line 706, in reraise
[rank6]:     raise exception
[rank6]: KeyError: Caught KeyError in DataLoader worker process 0.
[rank6]: Original Traceback (most recent call last):
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 253, in _worker_loop
[rank6]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 80, in create_fetcher
[rank6]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank6]:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 22, in __init__
[rank6]:     self.dataset_iter = iter(dataset)
[rank6]:   File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 232, in __iter__
[rank6]:     self._resume(workers_chunks, workers_intervals)
[rank6]:   File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 289, in _resume
[rank6]:     chunks_index, indexes = _replay_chunks_sampling(
[rank6]:   File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 534, in _replay_chunks_sampling
[rank6]:     if indexes[worker_idx] >= size:
[rank6]: KeyError: 48

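Recently, we changed the worker_idx to be global across all ranks, which is likely the cause: with 8 workers per rank, worker 0 on rank 6 has global index 6 × 8 + 0 = 48, which is exactly the key missing in the KeyError above.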

repro.py and checkpoint.pt are available in this Studio.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions