diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 005cb54d4..6e1ff50a6 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -45,9 +45,16 @@ jobs: uv pip install -e ".[extras]" -r requirements/test.txt -U -q uv pip list - - name: Tests - working-directory: tests - run: pytest . -v --cov=litdata --durations=100 + - name: Run fast tests in parallel + run: | + pytest \ + tests/streaming tests/utilities \ + tests/test_cli.py tests/test_debugger.py \ + -n 2 --cov=litdata --cov-append --cov-report= --durations=120 + + - name: Run processing tests sequentially + run: | + pytest tests/processing tests/raw --cov=litdata --cov-append --cov-report= --durations=90 - name: Statistics continue-on-error: true diff --git a/requirements/test.txt b/requirements/test.txt index 505fe9d5c..32a6fc893 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -9,6 +9,7 @@ pytest-cov ==6.2.1 pytest-timeout ==2.4.0 pytest-rerunfailures ==15.1 pytest-random-order ==1.1.1 +pytest-xdist >=3.8.0 pandas pyarrow >=20.0.0 polars >1.0.0 diff --git a/tests/conftest.py b/tests/conftest.py index a2e54aadb..e4eb8d3ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import signal import sys import threading from collections import OrderedDict @@ -16,7 +17,7 @@ from litdata.utilities.dataset_utilities import get_default_cache_dir -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="session") def teardown_process_group(): """Ensures distributed process group gets closed before the next test runs.""" yield @@ -25,9 +26,8 @@ def teardown_process_group(): @pytest.fixture(autouse=True) -def set_env(): - # Set environment variable before each test to configure BaseWorker's maximum wait time - os.environ["DATA_OPTIMIZER_TIMEOUT"] = "20" +def disable_signals(monkeypatch): + monkeypatch.setattr(signal, "signal", lambda *args, **kwargs: None) @pytest.fixture @@ -132,7 +132,7 @@ def lightning_sdk_mock(monkeypatch): return lightning_sdk -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="session") def _thread_police(): """Attempts stopping left-over threads to avoid test interactions. diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 68e8cbdc2..ddb517e4d 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -319,7 +319,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir): assert count >= 25, "There should be at least 25 batches in the third epoch" -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) def test_resume_dataloader_with_new_dataset(tmpdir): dataset_1_path = tmpdir.join("dataset_1") dataset_2_path = tmpdir.join("dataset_2") diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c3c3ff806..4332bec79 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -311,7 +311,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")), ], ) -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compression): seed_everything(42) @@ -364,7 +364,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compr ), ], ) -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, compression): seed_everything(42) @@ -412,7 +412,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, comp pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")), ], ) -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) def test_streaming_dataset_distributed_full_shuffle_even_multi_nodes(drop_last, tmpdir, compression): seed_everything(42) @@ -685,7 +685,7 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir): assert result == expected -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) def test_dataset_for_text_tokens_with_large_block_size_multiple_workers(tmpdir): # test to reproduce ERROR: Unexpected segmentation fault encountered in worker seed_everything(42) @@ -1077,7 +1077,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from) -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") def test_dataset_valid_state(tmpdir, monkeypatch): seed_everything(42) @@ -1213,7 +1213,7 @@ def fn(remote_chunkpath: str, local_chunkpath: str): dataset._validate_state_dict() -@pytest.mark.timeout(60) +@pytest.mark.timeout(90) @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") def test_dataset_valid_state_override(tmpdir, monkeypatch): seed_everything(42) diff --git a/tests/utilities/test_env.py b/tests/utilities/test_env.py index 064d93804..352bf15c3 100644 --- a/tests/utilities/test_env.py +++ b/tests/utilities/test_env.py @@ -2,9 +2,9 @@ def test_distributed_env_from_env(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", 2) - monkeypatch.setenv("GLOBAL_RANK", 1) - monkeypatch.setenv("NNODES", 2) + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "1") + monkeypatch.setenv("NNODES", "2") dist_env = _DistributedEnv.detect() assert dist_env.world_size == 2