diff --git a/olmo/data/iterable_dataset.py b/olmo/data/iterable_dataset.py
index dbb33fec1..79b5c8381 100644
--- a/olmo/data/iterable_dataset.py
+++ b/olmo/data/iterable_dataset.py
@@ -16,7 +16,9 @@ class IterableDataset(torch.utils.data.IterableDataset[T]):
     """
     Adapted from PyTorch's DistributedSampler, this wraps a Dataset or arbitrary sequence
-    as an IterableDataset that can be deterministically restarted at any point by setting `start_step`.
+    as an IterableDataset that can be deterministically restarted at any point by setting `start_step`,
+    which should be a multiple of your per-device batch size.
+    Similarly `max_steps`, if set, should be a multiple of per-device batch size.
     """
 
 
     def __init__(
@@ -27,7 +29,9 @@ def __init__(
         start_step: int = 0,
         max_steps: Optional[int] = None,
         shuffle: bool = True,
-        drop_last: bool = False
+        drop_last: bool = False,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ):
         self.dataset = dataset
         self.seed = seed
@@ -35,8 +39,12 @@ def __init__(
         self.max_steps = max_steps
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.rank = global_rank()
-        self.world_size = dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1
+        self.rank = rank if rank is not None else global_rank()
+        self.world_size = (
+            world_size
+            if world_size is not None
+            else (dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1)
+        )
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
         if self.drop_last and len(self.dataset) % self.world_size != 0:  # type: ignore[arg-type]
diff --git a/tests/data/iterable_dataset_test.py b/tests/data/iterable_dataset_test.py
new file mode 100644
index 000000000..9789845c6
--- /dev/null
+++ b/tests/data/iterable_dataset_test.py
@@ -0,0 +1,30 @@
+from olmo.data import IterableDataset
+
+
+def test_iterable_dataset_size():
+    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False)
+    assert dataset.total_size == 20
+    assert list(dataset) == list(range(0, 20, 2))
+
+    dataset = IterableDataset(list(range(20)), world_size=3, rank=0, shuffle=False, drop_last=False)
+    assert dataset.total_size == 21
+    assert list(dataset) == list(range(0, 20, 3))
+
+    dataset = IterableDataset(list(range(20)), world_size=3, rank=2, shuffle=False, drop_last=False)
+    assert list(dataset) == list(range(2, 18, 3)) + [0]
+
+    dataset = IterableDataset(list(range(20)), world_size=3, rank=0, shuffle=False, drop_last=True)
+    assert dataset.total_size == 18
+    assert list(dataset) == list(range(0, 18, 3))
+
+
+def test_iterable_dataset_max_steps():
+    batch_size = 2
+    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False, max_steps=batch_size * 3)
+    assert list(dataset) == [0, 2, 4, 6, 8, 10]
+
+
+def test_iterable_dataset_start_step():
+    batch_size = 2
+    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False, start_step=batch_size * 3)
+    assert list(dataset) == [12, 14, 16, 18]