diff --git a/README.md b/README.md
index 003eac480..78e91a5ef 100644
--- a/README.md
+++ b/README.md
@@ -520,6 +520,48 @@ from litdata import StreamingDataset
 dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
 ```
 
+
+
+<details>
+  <summary> ✅ Optimize dataset in distributed environment</summary>
+&nbsp;
+
+Lightning can distribute large workloads across hundreds of machines in parallel. This can reduce the time to complete a data processing task from weeks to minutes by scaling to enough machines.
+
+To apply the `optimize` operator across multiple machines, simply provide the `num_nodes` and `machine` arguments as follows:
+
+```python
+import os
+from litdata import optimize, Machine
+
+def compress(index):
+    return (index, index ** 2)
+
+optimize(
+    fn=compress,
+    inputs=list(range(100)),
+    num_workers=2,
+    output_dir="my_output",
+    chunk_bytes="64MB",
+    num_nodes=2,
+    machine=Machine.DATA_PREP,  # You can select between dozens of optimized machines
+)
+```
+
+If the `output_dir` is a local path, the optimized dataset will be present in `/teamspace/jobs/{job_name}/nodes-0/my_output`. Otherwise, it will be stored in the specified `output_dir`.
+
+Read the optimized dataset:
+
+```python
+from litdata import StreamingDataset
+
+output_dir = "/teamspace/jobs/litdata-optimize-2024-07-08/nodes.0/my_output"
+
+dataset = StreamingDataset(output_dir)
+
+print(dataset[:])
+```
+</details>
diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 69b1f9b1f..4741b75dd 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -139,7 +139,7 @@ def __iter__(self) -> Iterator[Any]:
         num_samples_yielded = None
 
         if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
-            num_samples_yielded = self._num_samples_yielded[worker_env.rank]
+            num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)
 
         self._iterator = _CombinedDatasetIterator(
             self._datasets,
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 88176ec29..92d1068b9 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -198,7 +198,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression):
         pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
     ],
 )
-@pytest.mark.timeout(30)
+@pytest.mark.timeout(60)
 def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compression):
     seed_everything(42)
 
@@ -807,8 +807,9 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir):
     return StreamingDataLoader(dataset, batch_size=2, num_workers=1)
 
 
-@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on windows and MacOs")
 @mock.patch.dict(os.environ, {}, clear=True)
+@pytest.mark.timeout(60)
 def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't
     have the same size."""
@@ -819,7 +820,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
 
     optimize(
         fn=_simple_preprocess,
-        inputs=list(range(8)),
+        inputs=list(range(5)),
         output_dir=str(tmpdir / "optimized"),
         chunk_size=190,
         num_workers=4,
@@ -830,6 +831,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
    batches_to_fetch = 16
     batch_to_resume_from = None
+    dataloader_state = None
     for i, batch in enumerate(train_dataloader):
         if i == batches_to_fetch:
             dataloader_state = train_dataloader.state_dict()
@@ -840,6 +842,8 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     shutil.rmtree(s3_cache_dir)
     os.mkdir(s3_cache_dir)
     train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    assert dataloader_state is not None
+    assert batch_to_resume_from is not None
     train_dataloader.load_state_dict(dataloader_state)
     # The next batch after resuming must match what we should have gotten next in the initial loop
     assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
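For context on the test changes above, the resume flow exercised by `test_dataset_resume_on_future_chunks` can be sketched with litdata's public API roughly as follows. This is an illustrative sketch, not part of the patch: the `my_output` path, the batch size of 2, and resuming after 16 batches are assumptions that simply mirror the test.

```python
from litdata import StreamingDataset, StreamingDataLoader

# Stream a previously optimized dataset (path is an example).
dataset = StreamingDataset("my_output")
loader = StreamingDataLoader(dataset, batch_size=2, num_workers=1)

state = None
for i, batch in enumerate(loader):
    if i == 16:
        # Capture the loader's position mid-epoch, e.g. right before a checkpoint or shutdown.
        state = loader.state_dict()
        break

# A fresh loader can pick up where the previous one left off,
# even when the resume point falls in a later chunk of a different size.
loader = StreamingDataLoader(StreamingDataset("my_output"), batch_size=2, num_workers=1)
loader.load_state_dict(state)
next_batch = next(iter(loader))
```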