Commit

add more tests
epwalsh committed May 24, 2023
1 parent 6d29ee4 commit d2442d6
Showing 2 changed files with 42 additions and 4 deletions.
16 changes: 12 additions & 4 deletions olmo/data/iterable_dataset.py
@@ -16,7 +16,9 @@
 class IterableDataset(torch.utils.data.IterableDataset[T]):
     """
     Adapted from PyTorch's DistributedSampler, this wraps a Dataset or arbitrary sequence
-    as an IterableDataset that can be deterministically restarted at any point by setting `start_step`.
+    as an IterableDataset that can be deterministically restarted at any point by setting `start_step`,
+    which should be a multiple of your per-device batch size.
+    Similarly `max_steps`, if set, should be a multiple of per-device batch size.
     """

     def __init__(
Expand All @@ -27,16 +29,22 @@ def __init__(
start_step: int = 0,
max_steps: Optional[int] = None,
shuffle: bool = True,
drop_last: bool = False
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
):
self.dataset = dataset
self.seed = seed
self.start_step = start_step
self.max_steps = max_steps
self.shuffle = shuffle
self.drop_last = drop_last
self.rank = global_rank()
self.world_size = dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1
self.rank = rank if rank is not None else global_rank()
self.world_size = (
world_size
if world_size is not None
else (dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1)
)
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
Expand Down
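The new `world_size` and `rank` arguments let the sharding logic be constructed and tested without an initialized `torch.distributed` process group, which is exactly what the new tests below rely on. What follows is a minimal, hypothetical sketch (not the OLMo implementation; `shard_indices` is a made-up name) of the index-sharding arithmetic those tests exercise: round-robin assignment across ranks, wrap-around padding when `drop_last=False`, truncation when `drop_last=True`, and per-device slicing for `start_step`/`max_steps`.

import math
from typing import List, Optional


def shard_indices(
    n: int,
    world_size: int,
    rank: int,
    drop_last: bool = False,
    start_step: int = 0,
    max_steps: Optional[int] = None,
) -> List[int]:
    # Hypothetical sketch of the sharding arithmetic; not the OLMo code.
    indices = list(range(n))
    if drop_last:
        # Truncate so every rank gets the same number of samples.
        total_size = (n // world_size) * world_size
        indices = indices[:total_size]
    else:
        # Pad by wrapping around to the start until the total divides evenly.
        total_size = math.ceil(n / world_size) * world_size
        indices = indices + indices[: total_size - n]
    # Round-robin assignment: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    per_rank = indices[rank:total_size:world_size]
    # start_step and max_steps are counted in per-device samples, which is why the
    # docstring asks for multiples of the per-device batch size.
    return per_rank[start_step:max_steps]

For example, shard_indices(20, world_size=3, rank=2) gives [2, 5, 8, 11, 14, 17, 0]: the trailing 0 is the wrap-around padding, matching the expectation in the new test file below.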
30 changes: 30 additions & 0 deletions tests/data/iterable_dataset_test.py
@@ -0,0 +1,30 @@
from olmo.data import IterableDataset


def test_iterable_dataset_size():
    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False)
    assert dataset.total_size == 20
    assert list(dataset) == list(range(0, 20, 2))

    dataset = IterableDataset(list(range(20)), world_size=3, rank=0, shuffle=False, drop_last=False)
    assert dataset.total_size == 21
    assert list(dataset) == list(range(0, 20, 3))

    dataset = IterableDataset(list(range(20)), world_size=3, rank=2, shuffle=False, drop_last=False)
    assert list(dataset) == list(range(2, 18, 3)) + [0]

    dataset = IterableDataset(list(range(20)), world_size=3, rank=0, shuffle=False, drop_last=True)
    assert dataset.total_size == 18
    assert list(dataset) == list(range(0, 18, 3))


def test_iterable_dataset_max_steps():
    batch_size = 2
    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False, max_steps=batch_size * 3)
    assert list(dataset) == [0, 2, 4, 6, 8, 10]


def test_iterable_dataset_start_step():
    batch_size = 2
    dataset = IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False, start_step=batch_size * 3)
    assert list(dataset) == [12, 14, 16, 18]
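The `start_step` test above restates the restart guarantee from the docstring: skipping the first `start_step` per-device samples is equivalent to resuming an uninterrupted pass over that rank's shard. Here is a small follow-on check (not part of this commit) that makes the equivalence explicit, using only constructor arguments the new tests already exercise:

from olmo.data import IterableDataset

full = list(IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False))
resumed = list(IterableDataset(list(range(20)), world_size=2, rank=0, shuffle=False, start_step=6))
# Both should end up as [12, 14, 16, 18], the tail of rank 0's shard.
assert resumed == full[6:]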
