Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/litdata/utilities/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def train_test_split(
streaming_dataset: StreamingDataset, splits: list[float], seed: int = 42
streaming_dataset: StreamingDataset, splits: list[float], seed: int = 42, shuffle: bool = True
) -> list[StreamingDataset]:
"""Splits a StreamingDataset into multiple subsets for training, testing, and validation.

Expand All @@ -24,6 +24,7 @@ def train_test_split(
splits: A list of floats representing the proportion of data to be allocated to each split
(e.g., [0.8, 0.1, 0.1] for 80% training, 10% testing, and 10% validation).
seed: An integer used to seed the random number generator for reproducibility.
shuffle: A boolean indicating whether to shuffle the data before splitting.

Returns:
List[StreamingDataset]: A list of StreamingDataset instances, where each element represents a split of the
Expand Down Expand Up @@ -71,9 +72,10 @@ def train_test_split(

dataset_length = sum([my_roi[1] - my_roi[0] for my_roi in dummy_subsampled_roi])

subsampled_chunks, dummy_subsampled_roi = shuffle_lists_together(
subsampled_chunks, dummy_subsampled_roi, np.random.RandomState([seed])
)
if shuffle:
subsampled_chunks, dummy_subsampled_roi = shuffle_lists_together(
subsampled_chunks, dummy_subsampled_roi, np.random.RandomState([seed])
)

item_count_list = [int(dataset_length * split) for split in splits]

Expand Down
31 changes: 31 additions & 0 deletions tests/utilities/test_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,34 @@ def test_train_test_split_with_streaming_dataloader(tmpdir, compression):
for curr_idx in _dl:
assert curr_idx not in visited_indices
visited_indices.add(curr_idx)


@pytest.mark.parametrize(
    "compression",
    [
        pytest.param(None),
        pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
    ],
)
def test_train_test_split_with_shuffle_parameter(tmpdir, compression):
    """Verify the ``shuffle`` flag: True reorders chunk files, False preserves the dataset's original order."""
    # Build a 100-item dataset spread across chunks of 10 items each.
    cache = Cache(str(tmpdir), chunk_size=10, compression=compression)
    for item in range(100):
        cache[item] = item
    cache.done()
    cache.merge()

    dataset = StreamingDataset(input_dir=str(tmpdir))

    # Split the same dataset twice: once with shuffling, once without.
    shuffled_train, shuffled_test = train_test_split(dataset, splits=[0.8, 0.2], shuffle=True)
    ordered_train, ordered_test = train_test_split(dataset, splits=[0.8, 0.2], shuffle=False)

    # An 80/20 split of 100 items yields 80 and 20 samples regardless of shuffling.
    for subset, expected_len in ((shuffled_train, 80), (ordered_train, 80), (shuffled_test, 20), (ordered_test, 20)):
        assert len(subset) == expected_len

    # With the default fixed seed, shuffling must produce a chunk-file order that
    # differs from the unshuffled one (deterministic, so this cannot flake)...
    with_shuffle = shuffled_train.subsampled_files + shuffled_test.subsampled_files
    without_shuffle = ordered_train.subsampled_files + ordered_test.subsampled_files
    assert with_shuffle != without_shuffle

    # ...while shuffle=False must leave the original file order intact.
    assert without_shuffle == dataset.subsampled_files
Loading