From d44449e1a309c535985ff73b3b73eeabe9d00220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 15:27:57 +0000 Subject: [PATCH] terminate threads to avoid test interactions --- tests/conftest.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 17e47ed85..d12ff2f62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ import sys +import threading from types import ModuleType from unittest.mock import Mock import pytest import torch.distributed +from litdata.streaming.reader import PrepareChunksThread @pytest.fixture(autouse=True) @@ -65,3 +67,31 @@ def lightning_sdk_mock(monkeypatch): lightning_sdk = ModuleType("lightning_sdk") monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk) return lightning_sdk + + +@pytest.fixture(autouse=True) +def _thread_police(): + """Attempts to stop left-over threads to avoid test interactions. + + Adapted from PyTorch Lightning. + + """ + active_threads_before = set(threading.enumerate()) + yield + active_threads_after = set(threading.enumerate()) + + for thread in active_threads_after - active_threads_before: + if isinstance(thread, PrepareChunksThread): + thread.force_stop() + continue + + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) + if thread.daemon and callable(stop): + # A daemon thread would anyway be stopped at the end of a program + # We do it preemptively here to reduce the risk of interactions with other tests that run after + stop() + assert not thread.is_alive() + elif thread.name == "QueueFeederThread": + thread.join(timeout=20) + else: + raise AssertionError(f"Test left zombie thread: {thread}")