Resuming Training with New Dataset Fails #263

Closed · schopra8 opened this issue Jul 24, 2024 · 6 comments · Fixed by #318
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

🐛 Bug

If you train a model with a particular dataset for N epochs and then want to continue training with a new dataset, LitData throws an exception.

To Reproduce

Steps to reproduce the behavior:

  1. Train a model with dataset-1
  2. Cancel training after the first checkpoint is saved
  3. Resume training with `trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)`, where the datamodule now points to dataset-2.
  4. Observe the following error:

```
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 219, in __iter__
[rank7]:     self._validate_state_dict()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 447, in _validate_state_dict
[rank7]:     raise ValueError(
[rank7]: ValueError: The provided input_dir URL state doesn't match the current one. Found s3://dataset-2 instead of s3://dataset-1.
```

Code sample
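
No sample was attached; the following is a minimal sketch of the repro. `MyLightningModule`, the batch size, worker count, and the checkpoint path are placeholders, not code from the actual run:

```python
import lightning as L
from litdata import StreamingDataset, StreamingDataLoader


class StreamingDataModule(L.LightningDataModule):
    """Hypothetical datamodule wrapping a litdata StreamingDataset."""

    def __init__(self, input_dir: str):
        super().__init__()
        self.input_dir = input_dir

    def train_dataloader(self):
        dataset = StreamingDataset(input_dir=self.input_dir)
        return StreamingDataLoader(dataset, batch_size=32, num_workers=8)


model = MyLightningModule()  # placeholder for the actual model
trainer = L.Trainer(max_epochs=10)

# Steps 1-2: train on dataset-1; cancel after the first checkpoint is saved.
trainer.fit(model, datamodule=StreamingDataModule("s3://dataset-1"))

# Step 3: resume from that checkpoint with a datamodule pointing at dataset-2.
# Step 4: a dataloader worker raises:
#   ValueError: The provided input_dir URL state doesn't match the current one.
trainer.fit(
    model,
    datamodule=StreamingDataModule("s3://dataset-2"),
    ckpt_path="lightning_logs/version_0/checkpoints/last.ckpt",
)
```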

Expected behavior

Training should resume with the optimizer states, model weights, etc. restored, but with a net-new dataset.

Environment

  • PyTorch Version (e.g., 1.0): 2.3.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2x8H100
  • Any other relevant information:

Additional context

tchaton (Collaborator) commented Jul 24, 2024

Hey @schopra8,

Did you consume exactly N epochs of the first dataset?

As a temporary hack, did you try dropping the dataloader state from the checkpoint before reloading it? This might unblock you.
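
In case it helps, here's a minimal sketch of that hack, assuming the loader state lives under the checkpoint's "loops" key (true for recent Lightning versions, but worth verifying against `ckpt.keys()` on your checkpoint):

```python
import torch

# Load the Lightning checkpoint and drop the fit-loop state, which carries
# the StreamingDataLoader state (including the old input_dir).
ckpt = torch.load("last.ckpt", map_location="cpu")
ckpt.pop("loops", None)  # key name is an assumption; inspect ckpt.keys() first
torch.save(ckpt, "last-no-loader-state.ckpt")

# Then resume from the stripped checkpoint: model weights and optimizer
# states are restored, while the new dataset starts from scratch.
# trainer.fit(model, datamodule=datamodule, ckpt_path="last-no-loader-state.ckpt")
```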

We should allow dropping the state entirely once the epoch has terminated, so training can be resumed with different parameters. Maybe we can do it here: https://github.com/Lightning-AI/litdata/blob/main/src/litdata/streaming/dataloader.py#L468.

Would you be interested in trying to contribute a fix?

schopra8 (Author) commented Jul 24, 2024

In reality, I'd use this after exactly N epochs of the first dataset had been consumed. In this dummy example, I manually killed training on the first dataset after K steps (less than one epoch) and then tried changing datasets. Would that require a different change to the codebase?

I'll try the hack today -- but definitely down to contribute a fix.

tchaton (Collaborator) commented Jul 24, 2024

Yes, we need to hack around and find the right fix. Feel free to open a draft PR and we can help you land a reliable fix.

tchaton (Collaborator) commented Jul 29, 2024

Hey @schopra8. Any updates?

schopra8 (Author) commented

@tchaton No updates on my end yet -- got busy with another modeling task. I'll be taking a crack at it in the next few days.

bhimrazy (Collaborator) commented

Hi @schopra8,

We've released an update with bug fixes, including one for this issue. Currently, the fix only covers `StreamingDataset`; we'll be adding fixes for `CombinedStreamingDataset` soon as well.

Please feel free to try it out and let us know how it goes. Thanks! 😊
