
Uneven number of batches returned across ranks in StreamingDataset/DataLoader #233

@awaelchli

Description


🐛 Bug

The StreamingDataLoader/StreamingDataset returns an uneven number of batches across ranks in a distributed setting.

Example:

import torch
from lightning.fabric import Fabric
from tqdm import tqdm
from litdata import optimize, StreamingDataLoader, StreamingDataset, TokensLoader


def tokenize(item):
    size = torch.randint(10, 20, size=(1, )).item()
    yield torch.randint(0, 1000, size=(size, ))


def get_dataloader():    
    train_dataset = StreamingDataset(
        input_dir="data/fake-data",
        item_loader=TokensLoader(block_size=10),
        # shuffle=True,
        # drop_last=True,
    )
    train_dataloader = StreamingDataLoader(
        train_dataset, 
        batch_size=2, 
        num_workers=1,
        # drop_last seems to have an influence here: 
        drop_last=True
    )
    return train_dataloader


def main():
    torch.manual_seed(42)
    fabric = Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    if fabric.global_rank == 0:
        optimize(
            fn=tokenize,
            inputs=list(range(100)),
            output_dir="data/fake-data",
            num_workers=2,
            chunk_size=100,
            mode="overwrite"
        )
    fabric.barrier()
    
    train_dataloader = get_dataloader()

    # print(f"Rank {fabric.global_rank}: Length = {len(train_dataloader)}")
    fabric.barrier()
    
    print("Start fetching")
    monitor = tqdm if fabric.global_rank == 0 else lambda x: x
    
    batches_fetched = 0
    for _ in monitor(train_dataloader):
        batches_fetched += 1

    print(f"Rank {fabric.global_rank} finished. Batches fetched: {batches_fetched}, length: {len(train_dataloader)}")
    fabric.barrier()


if __name__ == "__main__":
    main()

Output:

Rank 0 finished. Batches fetched: 16, length: 16
Rank 1 finished. Batches fetched: 17, length: 16
Rank 3 finished. Batches fetched: 17, length: 16
Rank 2 finished. Batches fetched: 17, length: 16

As you can see, the number of batches actually fetched differs across ranks (16 vs. 17), even though len(train_dataloader) reports the same value of 16 on every rank.
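For reference, one way to make the mismatch explicit in the script above is to gather the per-rank counts with Fabric's all_gather collective and print them from rank 0. This is a minimal sketch that assumes the fabric and batches_fetched variables from the reproduction:

# Sketch: gather the per-rank batch counts so the mismatch is visible on rank 0.
# Assumes `fabric` and `batches_fetched` from the reproduction script above.
counts = fabric.all_gather(torch.tensor(batches_fetched))
if fabric.global_rank == 0:
    print(f"Batches fetched per rank: {counts.tolist()}")  # e.g. [16, 17, 17, 17]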

The docs on drop_last state:

drop_last: If `True`, drops the last items to ensure that
	all processes/workers return the same amount of data.
	The argument `drop_last` is set to `True` in a distributed setting
	and `False` otherwise.

At the moment, this doesn't seem to work as described.
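For completeness, the reproduction only passes drop_last=True to the StreamingDataLoader; the dataset-level drop_last (left commented out in the snippet above) can be set as well. Below is a sketch of that variant. Whether it changes the per-rank counts is exactly what this issue is about, so it is not presented as a fix:

def get_dataloader_drop_last_everywhere():
    # Same as get_dataloader() above, but with drop_last=True on the dataset too.
    train_dataset = StreamingDataset(
        input_dir="data/fake-data",
        item_loader=TokensLoader(block_size=10),
        drop_last=True,
    )
    return StreamingDataLoader(train_dataset, batch_size=2, num_workers=1, drop_last=True)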
