
Resuming Training w/ Streaming Dataset on DDP with Multiple Nodes Fails #248

Description

@schopra8

πŸ› Bug

We trained a model for several epochs on multiple nodes, and we wanted to continue training with PyTorch Lightning and LitData.

✅ When we resume training on a single device, resumption works as expected.
✅ When we resume training on a single node with N devices, resumption works as expected.
❌ When we resume training on multiple nodes with N devices, resumption fails.

To Reproduce

Run trainer.fit with an existing checkpoint using DDP across multiple nodes (see the code sample below):

Stack trace:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/train.py", line 70, in <module>
[rank7]:     main(config)
[rank7]:   File "/home/train.py", line 50, in main
[rank7]:     trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank7]:     batch = super().__next__()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank7]:     batch = next(self.iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank7]:     out = next(self._iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank7]:     out[i] = next(self.iterators[i])
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 628, in __iter__
[rank7]:     for batch in super().__iter__():
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank7]:     data = self._next_data()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1326, in _next_data
[rank7]:     return self._process_data(data)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank7]:     data.reraise()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank7]:     raise exception
[rank7]: IndexError: Caught IndexError in DataLoader worker process 1.
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 144, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 211, in __iter__
[rank7]:     self._resume(chunks_replica, intervals_replica)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 272, in _resume
[rank7]:     interval = self.worker_intervals[self.chunk_index]
[rank7]: IndexError: list index out of range
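
The failing lookup is interval = self.worker_intervals[self.chunk_index] in _resume, i.e. the restored chunk index points past the interval list assigned to that worker. Below is a sketch of one way such a mismatch could arise -- the numbers and names are purely illustrative, not litdata internals:

# Purely illustrative -- not litdata code. Shows how a chunk index
# computed under one chunk-partitioning layout can overflow the
# shorter per-worker interval list produced under another layout.
num_chunks = 128
devices, num_workers = 8, 2

def intervals_per_worker(num_nodes: int) -> int:
    # Chunks are split evenly across every (node, device, worker) slot.
    return num_chunks // (num_nodes * devices * num_workers)

one_node = intervals_per_worker(num_nodes=1)   # 8 intervals per worker
two_nodes = intervals_per_worker(num_nodes=2)  # 4 intervals per worker

worker_intervals = list(range(two_nodes))
chunk_index = one_node - 1  # valid under the 1-node layout only

try:
    interval = worker_intervals[chunk_index]
except IndexError:
    print("IndexError: list index out of range")  # matches the trace

If that guess is right, anything that changes the per-worker chunk share between save and resume (the node count, subsampling, or the combined dataset's bookkeeping) could trigger it.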

Code sample

I've scrubbed my code below --

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Create dataset module
custom_data_module = CustomDataModule(config)

# Initialize the model
model = Model()

# Define the PyTorch Lightning Trainer
wandb_logger = WandbLogger(**config.wandb_logger)
device_stats_monitor = DeviceStatsMonitor()
strategy = DDPStrategy(find_unused_parameters=True)
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[device_stats_monitor],
    strategy=strategy,
    max_epochs=4,
    val_check_interval=500,
    accelerator='gpu',
    devices=8,
    num_nodes=2,
    enable_progress_bar=True,
    log_every_n_steps=50,
    precision=32,
    default_root_dir='/scratch/lightning_logs'
)
trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)


class CustomDataModule(LightningDataModule):
    """
    Custom Data Module wraps training/validation StreamingDataset objects.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # Create train datasets
            self.train_datasets = []
            for ds_config in self.config.datasets.train:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.train_datasets.append(dataset)
            self.train_dataset = CombinedStreamingDataset(self.train_datasets)

            # Create validation datasets
            self.val_datasets = []
            for ds_config in self.config.datasets.val:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.val_datasets.append(dataset)
            self.val_dataset = CombinedStreamingDataset(self.val_datasets)

    def train_dataloader(self):
        return StreamingDataLoader(
            self.train_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )

    def val_dataloader(self):
        return StreamingDataLoader(
            self.val_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )
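
For what it's worth, litdata's documented loader-state round trip works for us in a single process, which is what makes the multi-node case stand out. Here is a minimal single-process sketch using StreamingDataLoader.state_dict() / load_state_dict() -- the input_dir, batch size, and worker count are placeholders:

from litdata import StreamingDataset, StreamingDataLoader

# Single-process sanity check; all values below are placeholders.
dataset = StreamingDataset(input_dir="s3://my-bucket/train-shards")
loader = StreamingDataLoader(dataset, batch_size=32, num_workers=2)

for step, batch in enumerate(loader):
    if step == 100:  # stop mid-epoch, as a checkpoint would
        break
state = loader.state_dict()  # the progress counters litdata resumes from

# A fresh loader restored from that state resumes after step 100.
dataset = StreamingDataset(input_dir="s3://my-bucket/train-shards")
loader = StreamingDataLoader(dataset, batch_size=32, num_workers=2)
loader.load_state_dict(state)

This matches the single-device ✅ above, so the saved state itself seems fine; the failure appears only once the chunks are partitioned across nodes.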

Expected behavior

Resuming training on multiple nodes should pick up from the checkpoint, just as it does on a single device or a single node.

Environment

  • PyTorch Version: 2.3.1
  • OS: Linux
  • How you installed PyTorch: poetry
  • Python version: 3.10
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2 nodes, 8x H100 each
