Skip to content

Resuming the StreamingDataLoader from a checkpoint makes the iterator hang #213

@awaelchli

Description

@awaelchli

🐛 Bug

In LitGPT, we currently aren't able to resume correctly from a pretrained checkpoint. The state the dataloader loads causes a stall in the iterator.

To Reproduce

Studio for reproduction.

I made a minimal repro without LitGPT involved:

import torch
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader


def get_dataloader():
    """Build the streaming dataloader used by the reproduction.

    Two ``StreamingDataset`` objects (SlimPajama train and Starcoder, both
    read from S3 through a ``TokensLoader`` with ``block_size=2049``) are
    mixed by a ``CombinedStreamingDataset`` with fixed weights and seed, then
    wrapped in a ``StreamingDataLoader`` with 8 workers and batch size 4.

    Returns:
        The configured ``StreamingDataLoader``.
    """
    slim_train_data = StreamingDataset(
        input_dir="s3://tinyllama-template/slimpajama/train",
        item_loader=TokensLoader(block_size=2049),
        shuffle=True,
        drop_last=True,
    )

    train_datasets = [
        slim_train_data,
        StreamingDataset(
            input_dir="s3://tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=2049),
            shuffle=True,
            drop_last=True,
        ),
    ]

    # Mix SlimPajama data and Starcoder data with these proportions:
    weights = (0.693584, 0.306416)
    train_data = CombinedStreamingDataset(
        datasets=train_datasets, seed=42, weights=weights, iterate_over_all=False
    )
    train_dataloader = StreamingDataLoader(
        train_data, batch_size=4, pin_memory=True, num_workers=8, drop_last=True
    )
    return train_dataloader


def dataloader_state():
    """Return the hard-coded dataloader state that triggers the hang.

    The values were copied from a real checkpoint so the problem can be
    reproduced without the checkpoint file. A fresh dict is built on every
    call.
    """
    worker_count = 8

    def dataset_entry(url, yielded):
        # The two per-dataset entries differ only in source URL and sample count.
        return {
            "num_samples_yielded": yielded,
            "num_workers": worker_count,
            "batch_size": 4,
            "current_epoch": 1,
            "input_dir_path": None,
            "input_dir_url": url,
            "item_loader": {"block_size": 2049},
            "drop_last": True,
            "seed": 42,
            "world_size": 1,
            "shuffle": True,
        }

    state = {}
    state["dataset"] = {
        "0": dataset_entry("s3://tinyllama-template/slimpajama/train", 852336),
        "1": dataset_entry("s3://tinyllama-template/starcoder", 376464),
    }
    state["current_epoch"] = 0
    state["latest_worker_idx"] = 7
    # Every worker index maps to the same pair of per-dataset counts.
    state["num_samples_yielded"] = {
        worker: [106542, 47058] for worker in range(worker_count)
    }
    return state


if __name__ == "__main__":
    # Build the dataloader, then restore the saved state before iterating.
    train_dataloader = get_dataloader()
    # Alternative (disabled): load the state from a saved file instead of the inlined dict.
    # checkpoint = torch.load("dataloader-state.pt")
    train_dataloader.load_state_dict(dataloader_state())
    
    print("start")
    # Reported behavior: the loop below never yields a batch.
    for batch in train_dataloader:  # hangs here
        print(batch)

When running this in a Studio, apply the following patch at line 40 of /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/litdata/streaming/client.py to avoid authentication errors:

        # The trailing `or True` makes this condition always true, so the client
        # below is created regardless of the credentials-file / Studio checks.
        if has_shared_credentials_file or not _IS_IN_STUDIO or True:
            # NOTE(review): signature_version=botocore.UNSIGNED presumably makes
            # the client send unsigned (anonymous) requests — confirm against botocore docs.
            self._client = boto3.client(
                "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}, signature_version=botocore.UNSIGNED)
            )

Expected behavior

Continue fetching data from where we left off.

Environment

LitData: 0.2.15

cc @sanyalsunny111

Metadata

Metadata

Assignees

Labels

bug (Something isn't working), help wanted (Extra attention is needed), priority 0

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions