🐛 Bug
The StreamingDataLoader/Dataset returns an uneven number of batches across the ranks.
Example:
```python
import torch
from lightning.fabric import Fabric
from tqdm import tqdm

from litdata import optimize, StreamingDataLoader, StreamingDataset, TokensLoader


def tokenize(item):
    size = torch.randint(10, 20, size=(1,)).item()
    yield torch.randint(0, 1000, size=(size,))


def get_dataloader():
    train_dataset = StreamingDataset(
        input_dir="data/fake-data",
        item_loader=TokensLoader(block_size=10),
        # shuffle=True,
        # drop_last=True,
    )
    train_dataloader = StreamingDataLoader(
        train_dataset,
        batch_size=2,
        num_workers=1,
        # drop_last seems to have an influence here:
        drop_last=True,
    )
    return train_dataloader


def main():
    torch.manual_seed(42)
    fabric = Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    if fabric.global_rank == 0:
        optimize(
            fn=tokenize,
            inputs=list(range(100)),
            output_dir="data/fake-data",
            num_workers=2,
            chunk_size=100,
            mode="overwrite",
        )
    fabric.barrier()

    train_dataloader = get_dataloader()
    # print(f"Rank {fabric.global_rank}: Length = {len(train_dataloader)}")
    fabric.barrier()

    print("Start fetching")
    monitor = tqdm if fabric.global_rank == 0 else lambda x: x
    batches_fetched = 0
    for _ in monitor(train_dataloader):
        batches_fetched += 1
    print(f"Rank {fabric.global_rank} finished. Batches fetched: {batches_fetched}, length: {len(train_dataloader)}")
    fabric.barrier()


if __name__ == "__main__":
    main()
```

Output:

```
Rank 0 finished. Batches fetched: 16, length: 16
Rank 1 finished. Batches fetched: 17, length: 16
Rank 3 finished. Batches fetched: 17, length: 16
Rank 2 finished. Batches fetched: 17, length: 16
```
As you can see, counting the batches on each rank shows an uneven number: rank 0 fetches 16 batches while ranks 1–3 fetch 17. However, `len(train_dataloader)` seems to return the correct value on every rank. A quick way to surface this mismatch programmatically is sketched below.
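For completeness, here is a minimal sketch of such a check, to be appended at the end of `main()` in the repro above. It assumes Fabric's `all_gather`, which stacks a tensor from every rank; the check itself is hypothetical and not part of the original script:

```python
# Hypothetical sanity check: gather the per-rank batch counts and compare
# them against the dataloader length reported on this rank.
counts = fabric.all_gather(torch.tensor(batches_fetched))  # tensor of shape (world_size,)
if fabric.global_rank == 0:
    print(f"Per-rank batch counts: {counts.tolist()}, len(dataloader): {len(train_dataloader)}")
    if len(set(counts.tolist())) != 1:
        print("Mismatch: ranks fetched different numbers of batches!")
```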
The docs on `drop_last` state:

```
drop_last: If `True`, drops the last items to ensure that
    all processes/workers return the same amount of data.
    The argument `drop_last` is set to `True` in a distributed setting
    and `False` otherwise.
```
At the moment, this doesn't seem to work as described.
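Until this is addressed in litdata, one possible user-side stopgap (only a sketch, not the library's intended fix) is to agree on a common batch budget across ranks and stop iterating once it is reached. The snippet below would replace the fetch loop in `main()` above; it assumes Fabric's `all_gather` and ignores `StreamingDataLoader`'s state/resumption logic:

```python
# Hypothetical workaround: cap every rank at the smallest reported per-rank
# length so that no rank consumes more batches than the others.
lengths = fabric.all_gather(torch.tensor(len(train_dataloader)))
max_batches = int(lengths.min())

batch_iter = iter(train_dataloader)
for _ in range(max_batches):
    batch = next(batch_iter)
    # ... training step goes here ...
```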