Skip to content

Commit

Permalink
fix: Bypass signal handler if running in a thread (#3251)
Browse files Browse the repository at this point in the history
* Bypass signal handler if running in a thread

- enables full sync compatibility if `run_sync=True`

* Ruff 🐶

* Refactor to use a context manager

* Add tests

* Add tests for executor factory when not in main thread

* Fix executor factory logic
  • Loading branch information
anticorrelator committed May 21, 2024
1 parent 3f58c4b commit 8c82306
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 30 deletions.
99 changes: 69 additions & 30 deletions packages/phoenix-evals/src/phoenix/evals/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,21 @@
import asyncio
import logging
import signal
import threading
import traceback
from typing import Any, Callable, Coroutine, List, Optional, Protocol, Sequence, Tuple, Union
from contextlib import contextmanager
from typing import (
Any,
Callable,
Coroutine,
Generator,
List,
Optional,
Protocol,
Sequence,
Tuple,
Union,
)

from phoenix.evals.exceptions import PhoenixException
from tqdm.auto import tqdm
Expand Down Expand Up @@ -168,7 +181,7 @@ def termination_handler(signum: int, frame: Any) -> None:
termination_event.set()
tqdm.write("Process was interrupted. The return value will be incomplete...")

signal.signal(self.termination_signal, termination_handler)
original_handler = signal.signal(self.termination_signal, termination_handler)
outputs = [self.fallback_return_value] * len(inputs)
progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)

Expand Down Expand Up @@ -209,7 +222,7 @@ def termination_handler(signum: int, frame: Any) -> None:
termination_event_watcher.cancel()

# reset the SIGTERM handler
signal.signal(self.termination_signal, signal.SIG_DFL) # reset the SIGTERM handler
signal.signal(self.termination_signal, original_handler) # reset the SIGTERM handler
return outputs

def run(self, inputs: Sequence[Any]) -> List[Any]:
Expand Down Expand Up @@ -244,7 +257,7 @@ def __init__(
max_retries: int = 10,
exit_on_error: bool = True,
fallback_return_value: Union[Unset, Any] = _unset,
termination_signal: signal.Signals = signal.SIGINT,
termination_signal: Optional[signal.Signals] = signal.SIGINT,
):
self.generate = generation_fn
self.fallback_return_value = fallback_return_value
Expand All @@ -259,35 +272,46 @@ def _signal_handler(self, signum: int, frame: Any) -> None:
tqdm.write("Process was interrupted. The return value will be incomplete...")
self._TERMINATE = True

@contextmanager
def _executor_signal_handling(self, signum: Optional[int]) -> Generator[None, None, None]:
original_handler = None
if signum is not None:
original_handler = signal.signal(signum, self._signal_handler)
try:
yield
finally:
signal.signal(signum, original_handler)
else:
yield

def run(self, inputs: Sequence[Any]) -> List[Any]:
signal.signal(self.termination_signal, self._signal_handler)
outputs = [self.fallback_return_value] * len(inputs)
progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)
with self._executor_signal_handling(self.termination_signal):
outputs = [self.fallback_return_value] * len(inputs)
progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)

for index, input in enumerate(inputs):
try:
for attempt in range(self.max_retries + 1):
if self._TERMINATE:
for index, input in enumerate(inputs):
try:
for attempt in range(self.max_retries + 1):
if self._TERMINATE:
return outputs
try:
result = self.generate(input)
outputs[index] = result
progress_bar.update()
break
except Exception as exc:
is_phoenix_exception = isinstance(exc, PhoenixException)
if attempt >= self.max_retries or is_phoenix_exception:
raise exc
else:
tqdm.write(f"Exception in worker on attempt {attempt + 1}: {exc}")
tqdm.write("Retrying...")
except Exception as exc:
tqdm.write(f"Exception in worker: {exc}")
if self.exit_on_error:
return outputs
try:
result = self.generate(input)
outputs[index] = result
else:
progress_bar.update()
break
except Exception as exc:
is_phoenix_exception = isinstance(exc, PhoenixException)
if attempt >= self.max_retries or is_phoenix_exception:
raise exc
else:
tqdm.write(f"Exception in worker on attempt {attempt + 1}: {exc}")
tqdm.write("Retrying...")
except Exception as exc:
tqdm.write(f"Exception in worker: {exc}")
if self.exit_on_error:
return outputs
else:
progress_bar.update()
signal.signal(self.termination_signal, signal.SIG_DFL) # reset the SIGTERM handler
return outputs


Expand All @@ -300,7 +324,22 @@ def get_executor_on_sync_context(
exit_on_error: bool = True,
fallback_return_value: Union[Unset, Any] = _unset,
) -> Executor:
if run_sync:
if threading.current_thread() is not threading.main_thread():
# run evals synchronously if not in the main thread

if run_sync is False:
logger.warning(
"Async evals execution is not supported in non-main threads. Falling back to sync."
)
return SyncExecutor(
sync_fn,
tqdm_bar_format=tqdm_bar_format,
exit_on_error=exit_on_error,
fallback_return_value=fallback_return_value,
termination_signal=None,
)

if run_sync is True:
return SyncExecutor(
sync_fn,
tqdm_bar_format=tqdm_bar_format,
Expand Down
121 changes: 121 additions & 0 deletions packages/phoenix-evals/tests/phoenix/evals/functions/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import os
import platform
import queue
import signal
import threading
import time
from unittest.mock import AsyncMock, Mock

Expand Down Expand Up @@ -244,6 +246,33 @@ def sync_fn(x):
assert results.count("test") > 100, "most inputs should not have been processed"


def test_sync_executor_defaults_sigint_handling():
def sync_fn(x):
return signal.getsignal(signal.SIGINT)

executor = SyncExecutor(
sync_fn,
max_retries=0,
fallback_return_value="test",
)
res = executor.run(["test"])
assert res[0] != signal.default_int_handler


def test_sync_executor_bypasses_sigint_handling_if_none():
def sync_fn(x):
return signal.getsignal(signal.SIGINT)

executor = SyncExecutor(
sync_fn,
max_retries=0,
fallback_return_value="test",
termination_signal=None,
)
res = executor.run(["test"])
assert res[0] == signal.default_int_handler


def test_sync_executor_retries():
mock_generate = Mock(side_effect=RuntimeError("Test exception"))
executor = SyncExecutor(mock_generate, max_retries=3)
Expand Down Expand Up @@ -316,3 +345,95 @@ def executor_in_sync_context():

executor = executor_in_sync_context()
assert isinstance(executor, SyncExecutor)


def test_executor_factory_returns_sync_in_threads():
def sync_fn():
pass

async def async_fn():
pass

exception_log = queue.Queue()

def run_test():
try:
executor = get_executor_on_sync_context(
sync_fn,
async_fn,
run_sync=True, # request a sync_executor
)
assert isinstance(executor, SyncExecutor)
assert executor.termination_signal is None
except Exception as e:
exception_log.put(e)

test_thread = threading.Thread(target=run_test)
test_thread.start()
test_thread.join()
if not exception_log.empty():
raise exception_log.get()


async def test_executor_factory_returns_sync_in_threads_even_if_async_context():
def sync_fn():
pass

async def async_fn():
pass

exception_log = queue.Queue()

async def run_test():
nest_asyncio.apply()
try:
executor = get_executor_on_sync_context(
sync_fn,
async_fn,
)
assert isinstance(executor, SyncExecutor)
assert executor.termination_signal is None
except Exception as e:
exception_log.put(e)

def async_task(loop):
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

loop = asyncio.new_event_loop()
test_thread = threading.Thread(target=async_task, args=(loop,))
test_thread.start()
test_thread.join()

if not exception_log.empty():
raise exception_log.get()


def test_executor_factory_returns_async_not_in_thread_if_async_context():
def sync_fn():
pass

async def async_fn():
pass

exception_log = queue.Queue()

async def run_test():
nest_asyncio.apply()
try:
executor = get_executor_on_sync_context(
sync_fn,
async_fn,
)
assert isinstance(executor, AsyncExecutor)
assert executor.termination_signal is not None
except Exception as e:
exception_log.put(e)

def async_task():
asyncio.run(run_test())

async_task()

if not exception_log.empty():
raise exception_log.get()

0 comments on commit 8c82306

Please sign in to comment.