
ParallelStreamingDataset with resume=True does not resume after the second epoch when breaking early #760

@philgzl


🐛 Bug

When iteration stops early (e.g. when limit_<stage>_batches is set in Trainer), a ParallelStreamingDataset with resume=True used with a StreamingDataLoader resumes correctly at the start of the second epoch, but from the third epoch onward it no longer resumes from where the previous epoch left off.

The reason for this is:

  • Start of first epoch: self.restore in StreamingDataLoader is False and self.current_epoch is 0, so self.load_state_dict(self.state_dict()) is not called. Samples are yielded as expected. We break early, and self.restore is still False.
  • Start of second epoch: self.restore is False and self.current_epoch is not 0, so self.load_state_dict(self.state_dict()) is called to resume. However, this also sets self.restore to True. Samples are yielded as expected. We break early, and self.restore is still True.
  • Start of third epoch: self.restore is True, so self.load_state_dict(self.state_dict()) is not called and no resumption happens. The same samples as in the previous epoch are yielded.
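The sequence above can be reproduced with a toy model of the loader's restore flag. This is a minimal sketch, not litdata's code: ToyLoader, resume_from, position, run_epochs and the fix_restore switch are hypothetical names. It only assumes what is described above: load_state_dict sets restore to True, and breaking out of the loop early leaves the flag unchanged. Clearing the flag right after resuming is shown as one plausible fix, not as the fix litdata adopted.

```python
class ToyLoader:
    """Hypothetical model of the restore-flag logic described above (not litdata code)."""

    def __init__(self, data, fix_restore=False):
        self.data = data
        self.restore = False
        self.current_epoch = 0
        self.position = 0     # index of the next sample to yield; what state_dict() saves
        self.resume_from = 0  # position taken from the last loaded state
        self.fix_restore = fix_restore

    def state_dict(self):
        return {"position": self.position}

    def load_state_dict(self, state):
        self.resume_from = state["position"]
        self.restore = True  # loading a state marks the loader as restored

    def __iter__(self):
        if not self.restore and self.current_epoch != 0:
            self.load_state_dict(self.state_dict())
            if self.fix_restore:
                self.restore = False  # hypothetical fix: clear the flag after resuming
        self.current_epoch += 1
        i = self.resume_from
        while True:  # cycling dataset: an epoch only ends when the caller breaks
            self.position = i + 1
            yield self.data[i % len(self.data)]
            i += 1


def run_epochs(loader, n_epochs=3, batches_per_epoch=4):
    """Iterate like the code sample: break out of every epoch early."""
    epochs = []
    for _ in range(n_epochs):
        seen = []
        for i, sample in enumerate(loader):
            seen.append(sample)
            if i == batches_per_epoch - 1:
                break  # stop early, like limit_<stage>_batches
        epochs.append(seen)
    return epochs


print(run_epochs(ToyLoader(list(range(10)))))
# [[0, 1, 2, 3], [4, 5, 6, 7], [4, 5, 6, 7]]
print(run_epochs(ToyLoader(list(range(10)), fix_restore=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]
```

Because the flag is only cleared in the fixed variant, the buggy variant skips the resume call in the third epoch and restarts from the state loaded at the start of the second epoch.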

To Reproduce

See code sample below.

Code sample
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset
from litdata.streaming import Cache

# Write a 10-sample dataset (values 0..9) to temp/, one sample per chunk
cache = Cache(input_dir="temp/", chunk_size=1)
dset_len = 10
for i in range(dset_len):
    cache[i] = i
cache.done()
cache.merge()

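# length=999 exceeds the dataset length, so the wrapped dataset is cycled;
# resume=True should make each epoch continue where the previous one stopped
dset = ParallelStreamingDataset([StreamingDataset("temp/")], length=999, resume=True)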
assert dset.is_cycling()

dloader = StreamingDataLoader(dset)

expected = 0

# epoch 1 (each epoch stops after 4 batches, i.e. at i == 3, as limit_<stage>_batches=4 would)
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break

# epoch 2
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break

# epoch 3
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # fails; repeats epoch 2
    expected = (expected + 1) % dset_len
    if i == 3:
        break

Expected behavior

The third epoch should resume from where the second epoch left off (in the code sample, it should yield samples 8, 9, 0 and 1). It should not yield the same samples as the second epoch.
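For reference, the samples each epoch should yield under correct resumption can be computed directly. This is a minimal sketch assuming batch size 1 (the StreamingDataLoader default used in the code sample), a fixed number of batches per epoch, and a dataset whose i-th sample has value i; expected_epochs is a hypothetical helper, not part of litdata.

```python
def expected_epochs(dset_len, batches_per_epoch, n_epochs):
    """Samples each epoch should yield if every epoch resumes where the previous one stopped.

    Hypothetical helper: assumes batch size 1 and that the i-th sample has value i.
    """
    epochs = []
    for epoch in range(n_epochs):
        start = epoch * batches_per_epoch  # samples consumed by all earlier epochs
        # wrap around the end of the dataset, since it is cycled
        epochs.append([(start + k) % dset_len for k in range(batches_per_epoch)])
    return epochs


print(expected_epochs(dset_len=10, batches_per_epoch=4, n_epochs=3))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]
```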

Labels: bug (Something isn't working), help wanted (Extra attention is needed)