🐛 Bug
A checkpoint saved from a StreamingDataLoader running on multiple ranks currently cannot be resumed.
```python
from lightning import Fabric
import torch

from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader


def get_dataloader():
    train_dataset = StreamingDataset(
        input_dir="data/openwebtext/optimized/train",
        item_loader=TokensLoader(block_size=1025),
        shuffle=True,
    )
    train_dataloader = StreamingDataLoader(
        train_dataset, batch_size=12, pin_memory=True, num_workers=8, drop_last=True
    )
    return train_dataloader


def dataloader_state():
    """Dataloader state extracted from the real checkpoint, for reproduction."""
    checkpoint = torch.load("checkpoint.pt", weights_only=False)
    state_dict = checkpoint["train_dataloader"]
    return state_dict


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=8)
    fabric.launch()

    train_dataloader = get_dataloader()
    train_dataloader.load_state_dict(dataloader_state())

    iterator = iter(train_dataloader)
    try:
        batch = next(iterator)
        print("No exception on rank", fabric.global_rank)
    except Exception as e:
        print("Exception on rank", fabric.global_rank)
        raise e
```

Output:
```
Exception on rank 4
Exception on rank 2
Exception on rank 5
Exception on rank 6
Exception on rank 1
Exception on rank 3
Exception on rank 7
No exception on rank 0
```
The exception on ranks > 0 is:
```
[rank6]: Traceback (most recent call last):
[rank6]: File "/teamspace/studios/this_studio/repro.py", line 37, in <module>
[rank6]: raise e
[rank6]: File "/teamspace/studios/this_studio/repro.py", line 34, in <module>
[rank6]: batch = next(iterator)
[rank6]: File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataloader.py", line 623, in __iter__
[rank6]: for batch in super().__iter__():
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
[rank6]: data = self._next_data()
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
[rank6]: return self._process_data(data)
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
[rank6]: data.reraise()
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/_utils.py", line 706, in reraise
[rank6]: raise exception
[rank6]: KeyError: Caught KeyError in DataLoader worker process 0.
[rank6]: Original Traceback (most recent call last):
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 253, in _worker_loop
[rank6]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 80, in create_fetcher
[rank6]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank6]: File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 22, in __init__
[rank6]: self.dataset_iter = iter(dataset)
[rank6]: File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 232, in __iter__
[rank6]: self._resume(workers_chunks, workers_intervals)
[rank6]: File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 289, in _resume
[rank6]: chunks_index, indexes = _replay_chunks_sampling(
[rank6]: File "/teamspace/studios/this_studio/src/litdata/src/litdata/streaming/dataset.py", line 534, in _replay_chunks_sampling
[rank6]: if indexes[worker_idx] >= size:
[rank6]: KeyError: 48
```
Recently, we changed `worker_idx` to be global across all ranks.

`repro.py` and `checkpoint.pt` are available in this Studio.
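The `KeyError: 48` is consistent with that change: with 8 ranks × 8 workers, a global worker index runs 0–63 (rank 6, worker 0 → 6 × 8 + 0 = 48), while the replayed per-rank state is keyed by local worker index 0–7. A minimal sketch of this mismatch (hypothetical names; `replay` only mirrors the failing lookup in `_replay_chunks_sampling`, not the real function):

```python
num_ranks = 8
workers_per_rank = 8

# State as restored from the checkpoint: keyed by *local* worker index 0..7.
indexes = {local_idx: 0 for local_idx in range(workers_per_rank)}


def replay(worker_idx: int) -> int:
    # Mirrors the failing lookup: indexing the per-rank state with a
    # *global* worker index raises KeyError on every rank > 0.
    return indexes[worker_idx]


rank = 6
global_worker_idx = rank * workers_per_rank + 0  # 48, as in the traceback

try:
    replay(global_worker_idx)
except KeyError as e:
    missing = e.args[0]
    print("KeyError:", missing)  # → KeyError: 48
```

Rank 0 is the only rank whose global worker indices (0–7) coincide with the local keys, which matches the output above where only rank 0 iterates without an exception.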