diff --git a/tests/conftest.py b/tests/conftest.py index 17e47ed85..d12ff2f62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ import sys +import threading from types import ModuleType from unittest.mock import Mock import pytest import torch.distributed +from litdata.streaming.reader import PrepareChunksThread @pytest.fixture(autouse=True) @@ -65,3 +67,31 @@ def lightning_sdk_mock(monkeypatch): lightning_sdk = ModuleType("lightning_sdk") monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk) return lightning_sdk + + +@pytest.fixture(autouse=True) +def _thread_police(): + """Attempts to stop left-over threads to avoid test interactions. + + Adapted from PyTorch Lightning. + + """ + active_threads_before = set(threading.enumerate()) + yield + active_threads_after = set(threading.enumerate()) + + for thread in active_threads_after - active_threads_before: + if isinstance(thread, PrepareChunksThread): + thread.force_stop() + continue + + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) + if thread.daemon and callable(stop): + # A daemon thread would anyway be stopped at the end of a program + # We do it preemptively here to reduce the risk of interactions with other tests that run after + stop() + assert not thread.is_alive() + elif thread.name == "QueueFeederThread": + thread.join(timeout=20) + else: + raise AssertionError(f"Test left zombie thread: {thread}")