Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import sys
import threading
from types import ModuleType
from unittest.mock import Mock

import pytest
import torch.distributed
from litdata.streaming.reader import PrepareChunksThread


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -65,3 +67,31 @@ def lightning_sdk_mock(monkeypatch):
lightning_sdk = ModuleType("lightning_sdk")
monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk)
return lightning_sdk


@pytest.fixture(autouse=True)
def _thread_police():
"""Attempts to stop left-over threads to avoid test interactions.

Adapted from PyTorch Lightning.

"""
active_threads_before = set(threading.enumerate())
yield
active_threads_after = set(threading.enumerate())

for thread in active_threads_after - active_threads_before:
if isinstance(thread, PrepareChunksThread):
thread.force_stop()
continue

stop = getattr(thread, "stop", None) or getattr(thread, "exit", None)
if thread.daemon and callable(stop):
# A daemon thread would anyway be stopped at the end of a program
# We do it preemptively here to reduce the risk of interactions with other tests that run after
stop()
assert not thread.is_alive()
elif thread.name == "QueueFeederThread":
thread.join(timeout=20)
else:
raise AssertionError(f"Test left zombie thread: {thread}")