🐛 Bug
We trained a model for several epochs on multiple nodes, and we wanted to continue training with PyTorch Lightning and LitData.
✅ When we resume training on a single device, resumption works as expected.
✅ When we resume training on a single node with N devices, resumption works as expected.
❌ When we resume training on multiple nodes with N devices, resumption fails.
To Reproduce
Run trainer.fit from an existing checkpoint with DDP across multiple nodes (see the code sample below):
Stack trace:
[rank7]: Traceback (most recent call last):
[rank7]: File "/home/train.py", line 70, in <module>
[rank7]: main(config)
[rank7]: File "/home/train.py", line 50, in main
[rank7]: trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank7]: call._call_and_handle_interrupt(
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
Failures:
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank7]: batch = super().__next__()
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank7]: batch = next(self.iterator)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank7]: out = next(self._iterator)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank7]: out[i] = next(self.iterators[i])
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 628, in __iter__
[rank7]: for batch in super().__iter__():
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank7]: data = self._next_data()
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1326, in _next_data
[rank7]: return self._process_data(data)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank7]: data.reraise()
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank7]: raise exception
[rank7]: IndexError: Caught IndexError in DataLoader worker process 1.
[rank7]: Original Traceback (most recent call last):
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]: self.dataset_iter = iter(dataset)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 144, in __iter__
[rank7]: self._iterator = _CombinedDatasetIterator(
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in __init__
[rank7]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in <listcomp>
[rank7]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 211, in __iter__
[rank7]: self._resume(chunks_replica, intervals_replica)
[rank7]: File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 272, in _resume
[rank7]: interval = self.worker_intervals[self.chunk_index]
[rank7]: IndexError: list index out of range
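The failing frame is the indexed access in StreamingDataset._resume. To illustrate the failure mode, here is a minimal sketch with hypothetical values (the real worker_intervals and chunk_index come from litdata's internal resume state, not from our run):

# Hypothetical values; only meant to show how the access at
# litdata/streaming/dataset.py:272 can go out of range.
worker_intervals = [[0, 50], [50, 100]]   # chunk intervals assigned to this worker on resume
chunk_index = 5                           # chunk index restored from the checkpoint state
interval = worker_intervals[chunk_index]  # IndexError: list index out of range

In other words, the chunk index restored from the checkpoint points past the list of chunk intervals assigned to the worker after resumption on multiple nodes.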
Code sample
I've scrubbed my code below:

# Imports restored for completeness; the original snippet omitted them
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

# Create the data module
custom_data_module = CustomDataModule(config)

# Initialize the model
model = Model()

# Define the PyTorch Lightning Trainer
wandb_logger = WandbLogger(**config.wandb_logger)
device_stats_monitor = DeviceStatsMonitor()
strategy = DDPStrategy(find_unused_parameters=True)
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[device_stats_monitor],
    strategy=strategy,
    max_epochs=4,
    val_check_interval=500,
    accelerator='gpu',
    devices=8,
    num_nodes=2,
    enable_progress_bar=True,
    log_every_n_steps=50,
    precision=32,
    default_root_dir='/scratch/lightning_logs',
)

# `ckpt` is the path to the existing checkpoint (scrubbed)
trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
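For debugging, the loop state that Lightning stored in the checkpoint can be inspected directly. A quick sketch (the exact key layout depends on the Lightning version, so treat the key names as assumptions):

import torch

# Lightning checkpoints are plain dicts; the fit-loop and dataloader progress
# live under top-level keys such as 'loops' in recent versions.
ckpt_data = torch.load(ckpt, map_location='cpu')
print(ckpt_data.keys())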
from pytorch_lightning import LightningDataModule
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

class CustomDataModule(LightningDataModule):
    """
    Custom Data Module that wraps the training/validation StreamingDataset objects.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # Create the train datasets
            self.train_datasets = []
            for ds_config in self.config.datasets.train:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample,
                )
                self.train_datasets.append(dataset)
            self.train_dataset = CombinedStreamingDataset(self.train_datasets)

            # Create the validation datasets
            self.val_datasets = []
            for ds_config in self.config.datasets.val:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample,
                )
                self.val_datasets.append(dataset)
            self.val_dataset = CombinedStreamingDataset(self.val_datasets)

    def train_dataloader(self):
        return StreamingDataLoader(
            self.train_dataset,
            collate_fn=collate_fn,  # collate_fn is defined elsewhere (scrubbed)
            **self.config.dataloader,
        )

    def val_dataloader(self):
        return StreamingDataLoader(
            self.val_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader,
        )
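For reference, litdata's resume path can also be exercised outside Lightning through StreamingDataLoader's own state handling. A minimal sketch of that pattern, with placeholder paths and sizes:

from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(input_dir='s3://my-bucket/my-data')  # placeholder path
loader = StreamingDataLoader(dataset, batch_size=64, num_workers=2)

for i, batch in enumerate(loader):
    if i == 10:
        state = loader.state_dict()  # the state Lightning captures in its checkpoint
        break

# Replaying the state and iterating again goes through StreamingDataset._resume(),
# the method that raises in the traceback above.
loader.load_state_dict(state)
for batch in loader:
    break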
Expected behavior
Resumption should work on multiple nodes, just as it does on a single device or a single node.
Environment
- PyTorch Version (e.g., 1.0): 2.3.1
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): poetry
- Build command you used (if compiling from source):
- Python version: 3.10
- CUDA/cuDNN version: 12.1
- GPU models and configuration: 2 nodes, 8x H100 each
- Any other relevant information: