From 8c8230606d173a55a2f84b2fbdbb48e920cbdb70 Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Tue, 21 May 2024 16:15:16 -0400 Subject: [PATCH] fix: Bypass signal handler if running in a thread (#3251) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bypass signal handler if running in a thread - enables full sync compatibility if `run_sync=True` * Ruff 🐶 * Refactor to use a context manager * Add tests * Add tests for executor factory when not in main thread * Fix executor factory logic --- .../src/phoenix/evals/executors.py | 99 +++++++++----- .../phoenix/evals/functions/test_executor.py | 121 ++++++++++++++++++ 2 files changed, 190 insertions(+), 30 deletions(-) diff --git a/packages/phoenix-evals/src/phoenix/evals/executors.py b/packages/phoenix-evals/src/phoenix/evals/executors.py index 6017eaccfe..a419f1e794 100644 --- a/packages/phoenix-evals/src/phoenix/evals/executors.py +++ b/packages/phoenix-evals/src/phoenix/evals/executors.py @@ -3,8 +3,21 @@ import asyncio import logging import signal +import threading import traceback -from typing import Any, Callable, Coroutine, List, Optional, Protocol, Sequence, Tuple, Union +from contextlib import contextmanager +from typing import ( + Any, + Callable, + Coroutine, + Generator, + List, + Optional, + Protocol, + Sequence, + Tuple, + Union, +) from phoenix.evals.exceptions import PhoenixException from tqdm.auto import tqdm @@ -168,7 +181,7 @@ def termination_handler(signum: int, frame: Any) -> None: termination_event.set() tqdm.write("Process was interrupted. The return value will be incomplete...") - signal.signal(self.termination_signal, termination_handler) + original_handler = signal.signal(self.termination_signal, termination_handler) outputs = [self.fallback_return_value] * len(inputs) progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format) @@ -209,7 +222,7 @@ def termination_handler(signum: int, frame: Any) -> None: termination_event_watcher.cancel() # reset the SIGTERM handler - signal.signal(self.termination_signal, signal.SIG_DFL) # reset the SIGTERM handler + signal.signal(self.termination_signal, original_handler) # reset the SIGTERM handler return outputs def run(self, inputs: Sequence[Any]) -> List[Any]: @@ -244,7 +257,7 @@ def __init__( max_retries: int = 10, exit_on_error: bool = True, fallback_return_value: Union[Unset, Any] = _unset, - termination_signal: signal.Signals = signal.SIGINT, + termination_signal: Optional[signal.Signals] = signal.SIGINT, ): self.generate = generation_fn self.fallback_return_value = fallback_return_value @@ -259,35 +272,46 @@ def _signal_handler(self, signum: int, frame: Any) -> None: tqdm.write("Process was interrupted. The return value will be incomplete...") self._TERMINATE = True + @contextmanager + def _executor_signal_handling(self, signum: Optional[int]) -> Generator[None, None, None]: + original_handler = None + if signum is not None: + original_handler = signal.signal(signum, self._signal_handler) + try: + yield + finally: + signal.signal(signum, original_handler) + else: + yield + def run(self, inputs: Sequence[Any]) -> List[Any]: - signal.signal(self.termination_signal, self._signal_handler) - outputs = [self.fallback_return_value] * len(inputs) - progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format) + with self._executor_signal_handling(self.termination_signal): + outputs = [self.fallback_return_value] * len(inputs) + progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format) - for index, input in enumerate(inputs): - try: - for attempt in range(self.max_retries + 1): - if self._TERMINATE: + for index, input in enumerate(inputs): + try: + for attempt in range(self.max_retries + 1): + if self._TERMINATE: + return outputs + try: + result = self.generate(input) + outputs[index] = result + progress_bar.update() + break + except Exception as exc: + is_phoenix_exception = isinstance(exc, PhoenixException) + if attempt >= self.max_retries or is_phoenix_exception: + raise exc + else: + tqdm.write(f"Exception in worker on attempt {attempt + 1}: {exc}") + tqdm.write("Retrying...") + except Exception as exc: + tqdm.write(f"Exception in worker: {exc}") + if self.exit_on_error: return outputs - try: - result = self.generate(input) - outputs[index] = result + else: progress_bar.update() - break - except Exception as exc: - is_phoenix_exception = isinstance(exc, PhoenixException) - if attempt >= self.max_retries or is_phoenix_exception: - raise exc - else: - tqdm.write(f"Exception in worker on attempt {attempt + 1}: {exc}") - tqdm.write("Retrying...") - except Exception as exc: - tqdm.write(f"Exception in worker: {exc}") - if self.exit_on_error: - return outputs - else: - progress_bar.update() - signal.signal(self.termination_signal, signal.SIG_DFL) # reset the SIGTERM handler return outputs @@ -300,7 +324,22 @@ def get_executor_on_sync_context( exit_on_error: bool = True, fallback_return_value: Union[Unset, Any] = _unset, ) -> Executor: - if run_sync: + if threading.current_thread() is not threading.main_thread(): + # run evals synchronously if not in the main thread + + if run_sync is False: + logger.warning( + "Async evals execution is not supported in non-main threads. Falling back to sync." + ) + return SyncExecutor( + sync_fn, + tqdm_bar_format=tqdm_bar_format, + exit_on_error=exit_on_error, + fallback_return_value=fallback_return_value, + termination_signal=None, + ) + + if run_sync is True: return SyncExecutor( sync_fn, tqdm_bar_format=tqdm_bar_format, diff --git a/packages/phoenix-evals/tests/phoenix/evals/functions/test_executor.py b/packages/phoenix-evals/tests/phoenix/evals/functions/test_executor.py index 8f4214dc20..def4a99a6a 100644 --- a/packages/phoenix-evals/tests/phoenix/evals/functions/test_executor.py +++ b/packages/phoenix-evals/tests/phoenix/evals/functions/test_executor.py @@ -1,7 +1,9 @@ import asyncio import os import platform +import queue import signal +import threading import time from unittest.mock import AsyncMock, Mock @@ -244,6 +246,33 @@ def sync_fn(x): assert results.count("test") > 100, "most inputs should not have been processed" +def test_sync_executor_defaults_sigint_handling(): + def sync_fn(x): + return signal.getsignal(signal.SIGINT) + + executor = SyncExecutor( + sync_fn, + max_retries=0, + fallback_return_value="test", + ) + res = executor.run(["test"]) + assert res[0] != signal.default_int_handler + + +def test_sync_executor_bypasses_sigint_handling_if_none(): + def sync_fn(x): + return signal.getsignal(signal.SIGINT) + + executor = SyncExecutor( + sync_fn, + max_retries=0, + fallback_return_value="test", + termination_signal=None, + ) + res = executor.run(["test"]) + assert res[0] == signal.default_int_handler + + def test_sync_executor_retries(): mock_generate = Mock(side_effect=RuntimeError("Test exception")) executor = SyncExecutor(mock_generate, max_retries=3) @@ -316,3 +345,95 @@ def executor_in_sync_context(): executor = executor_in_sync_context() assert isinstance(executor, SyncExecutor) + + +def test_executor_factory_returns_sync_in_threads(): + def sync_fn(): + pass + + async def async_fn(): + pass + + exception_log = queue.Queue() + + def run_test(): + try: + executor = get_executor_on_sync_context( + sync_fn, + async_fn, + run_sync=True, # request a sync_executor + ) + assert isinstance(executor, SyncExecutor) + assert executor.termination_signal is None + except Exception as e: + exception_log.put(e) + + test_thread = threading.Thread(target=run_test) + test_thread.start() + test_thread.join() + if not exception_log.empty(): + raise exception_log.get() + + +async def test_executor_factory_returns_sync_in_threads_even_if_async_context(): + def sync_fn(): + pass + + async def async_fn(): + pass + + exception_log = queue.Queue() + + async def run_test(): + nest_asyncio.apply() + try: + executor = get_executor_on_sync_context( + sync_fn, + async_fn, + ) + assert isinstance(executor, SyncExecutor) + assert executor.termination_signal is None + except Exception as e: + exception_log.put(e) + + def async_task(loop): + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + + loop = asyncio.new_event_loop() + test_thread = threading.Thread(target=async_task, args=(loop,)) + test_thread.start() + test_thread.join() + + if not exception_log.empty(): + raise exception_log.get() + + +def test_executor_factory_returns_async_not_in_thread_if_async_context(): + def sync_fn(): + pass + + async def async_fn(): + pass + + exception_log = queue.Queue() + + async def run_test(): + nest_asyncio.apply() + try: + executor = get_executor_on_sync_context( + sync_fn, + async_fn, + ) + assert isinstance(executor, AsyncExecutor) + assert executor.termination_signal is not None + except Exception as e: + exception_log.put(e) + + def async_task(): + asyncio.run(run_test()) + + async_task() + + if not exception_log.empty(): + raise exception_log.get()