Skip to content
Permalink
Browse files

Merge pull request #1786 from PrefectHQ/better-timeouts

Better timeouts
  • Loading branch information...
cicdw committed Dec 2, 2019
2 parents df67840 + 7648d1d commit 0e24b500cf8dbbaa038ce6c48a37ea2e124661aa
Showing with 213 additions and 7 deletions.
  1. +1 −0 CHANGELOG.md
  2. +114 −1 src/prefect/utilities/executors.py
  3. +31 −0 tests/core/test_flow.py
  4. +37 −3 tests/engine/test_task_runner.py
  5. +30 −3 tests/utilities/test_executors.py
@@ -19,6 +19,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/
### Fixes

- Fix issue with `flow.visualize()` for mapped tasks which are skipped - [#1765](https://github.com/PrefectHQ/prefect/issues/1765)
- Fix issue with timeouts only being softly enforced - [#1145](https://github.com/PrefectHQ/prefect/issues/1145), [#1686](https://github.com/PrefectHQ/prefect/issues/1686)
- Fix issue with `flow.update()` not transferring constants - [#1785](https://github.com/PrefectHQ/prefect/pull/1785)

### Deprecations
@@ -1,8 +1,11 @@
import datetime
import multiprocessing
import signal
import sys
import threading
import time
import warnings

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout
from functools import wraps
@@ -136,12 +139,98 @@ def inner(
return inner


def main_thread_timeout(
fn: Callable, *args: Any, timeout: int = None, **kwargs: Any
) -> Any:
"""
Helper function for implementing timeouts on function executions.
Implemented by setting a `signal` alarm on a timer. Must be run in the main thread.
Args:
- fn (callable): the function to execute
- *args (Any): arguments to pass to the function
- timeout (int): the length of time to allow for
execution before raising a `TimeoutError`, represented as an integer in seconds
- **kwargs (Any): keyword arguments to pass to the function
Returns:
- the result of `f(*args, **kwargs)`
Raises:
- TimeoutError: if function execution exceeds the allowed timeout
- ValueError: if run from outside the main thread
"""

if timeout is None:
return fn(*args, **kwargs)

def error_handler(signum, frame): # type: ignore
raise TimeoutError("Execution timed out.")

try:
signal.signal(signal.SIGALRM, error_handler)
signal.alarm(timeout)
return fn(*args, **kwargs)
finally:
signal.alarm(0)


def multiprocessing_timeout(
fn: Callable, *args: Any, timeout: int = None, **kwargs: Any
) -> Any:
"""
Helper function for implementing timeouts on function executions.
Implemented by spawning a new multiprocess.Process() and joining with timeout.
Args:
- fn (callable): the function to execute
- *args (Any): arguments to pass to the function
- timeout (int): the length of time to allow for
execution before raising a `TimeoutError`, represented as an integer in seconds
- **kwargs (Any): keyword arguments to pass to the function
Returns:
- the result of `f(*args, **kwargs)`
Raises:
- AssertionError: if run from a daemonic process
- TimeoutError: if function execution exceeds the allowed timeout
"""

if timeout is None:
return fn(*args, **kwargs)

def retrieve_value(
*args: Any, _container: multiprocessing.Queue, _ctx_dict: dict, **kwargs: Any
) -> None:
"""Puts the return value in a multiprocessing-safe container"""
try:
with prefect.context(_ctx_dict):
val = fn(*args, **kwargs)
_container.put(val)
except Exception as exc:
_container.put(exc)

q = multiprocessing.Queue() # type: multiprocessing.Queue
kwargs["_container"] = q
kwargs["_ctx_dict"] = prefect.context.to_dict()
p = multiprocessing.Process(target=retrieve_value, args=args, kwargs=kwargs)
p.start()
p.join(timeout)
p.terminate()
if not q.empty():
res = q.get()
if isinstance(res, Exception):
raise res
return res
else:
raise TimeoutError("Execution timed out.")


def timeout_handler(
fn: Callable, *args: Any, timeout: int = None, **kwargs: Any
) -> Any:
"""
Helper function for implementing timeouts on function executions.
Implemented via `concurrent.futures.ThreadPoolExecutor`.
The exact implementation varies depending on whether this function is being run
in the main thread or a non-daemonic subprocess. If this is run from a daemonic subprocess or on Windows,
the task is run in a `ThreadPoolExecutor` and only a soft timeout is enforced, meaning
a `TimeoutError` is raised at the appropriate time but the task continues running in the background.
Args:
- fn (callable): the function to execute
@@ -156,9 +245,33 @@ def timeout_handler(
Raises:
- TimeoutError: if function execution exceeds the allowed timeout
"""
# if no timeout, just run the function
if timeout is None:
return fn(*args, **kwargs)

# if we are running the main thread, use a signal to stop execution at the appropriate time;
# else if we are running in a non-daemonic process, spawn a subprocess to kill at the appropriate time
if not sys.platform.startswith("win"):
if threading.current_thread() is threading.main_thread():
return main_thread_timeout(fn, *args, timeout=timeout, **kwargs)
elif multiprocessing.current_process().daemon is False:
return multiprocessing_timeout(fn, *args, timeout=timeout, **kwargs)

msg = (
"This task is running in a daemonic subprocess; "
"consequently Prefect can only enforce a soft timeout limit, i.e., "
"if your Task reaches its timeout limit it will enter a TimedOut state "
"but continue running in the background."
)
else:
msg = (
"This task is running on Windows; "
"consequently Prefect can only enforce a soft timeout limit, i.e., "
"if your Task reaches its timeout limit it will enter a TimedOut state "
"but continue running in the background."
)

warnings.warn(msg)
executor = ThreadPoolExecutor()

def run_with_ctx(*args: Any, _ctx_dict: dict, **kwargs: Any) -> Any:
@@ -4,6 +4,7 @@
import random
import sys
import tempfile
import time
import uuid
from unittest.mock import MagicMock, patch

@@ -2505,3 +2506,33 @@ def do_nothing(arg):

flow_state = flow.run()
assert flow_state.is_successful()


@pytest.mark.skipif(
sys.platform == "win32", reason="Windows doesn't support any timeout logic"
)
@pytest.mark.parametrize("executor", ["local", "sync", "mthread"], indirect=True)
def test_timeout_actually_stops_execution(executor):
with tempfile.TemporaryDirectory() as call_dir:
FILE = os.path.join(call_dir, "test.txt")

@prefect.task(timeout=1)
def slow_fn():
"Runs for 1.5 seconds, writes to file 7 times"
iters = 0
while iters < 6:
time.sleep(0.25)
with open(FILE, "a") as f:
f.write("called\n")
iters += 1

flow = Flow("timeouts", tasks=[slow_fn])
state = flow.run(executor=executor)

# if it continued running, would run for 1 more second
time.sleep(0.5)
with open(FILE, "r") as g:
contents = g.read()

assert len(contents.split("\n")) <= 4
assert state.is_failed()
@@ -1,11 +1,14 @@
import collections
import os
import pendulum
import pytest
import sys
import tempfile

from datetime import datetime, timedelta
from time import sleep
from unittest.mock import MagicMock

import pendulum
import pytest

import prefect
from prefect.client import Secret
from prefect.core.edge import Edge
@@ -298,6 +301,37 @@ def test_task_runner_accepts_dictionary_of_edges():
assert state.result == 2


@pytest.mark.skipif(
sys.platform == "win32", reason="Windows doesn't support any timeout logic"
)
@pytest.mark.parametrize(
"executor", ["local", "sync", "mproc", "mthread"], indirect=True
)
def test_timeout_actually_stops_execution(executor):
with tempfile.TemporaryDirectory() as call_dir:
FILE = os.path.join(call_dir, "test.txt")

@prefect.task(timeout=1)
def slow_fn():
"Runs for 1.5 seconds, writes to file 6 times"
iters = 0
while iters < 6:
sleep(0.25)
with open(FILE, "a") as f:
f.write("called\n")
iters += 1

state = TaskRunner(slow_fn).run(executor=executor)

# if it continued running, would run for 1 more second
sleep(0.5)
with open(FILE, "r") as g:
contents = g.read()

assert len(contents.split("\n")) <= 4
assert state.is_failed()


def test_task_runner_can_handle_timeouts_by_default():
sleeper = SlowTask(timeout=1)
upstream_state = Success(result=2)
@@ -1,6 +1,8 @@
import os
import multiprocessing
import sys
import threading
import tempfile
import time
from datetime import timedelta
from unittest.mock import MagicMock
@@ -127,6 +129,33 @@ def test_timeout_handler_times_out():
timeout_handler(slow_fn, timeout=1)


@pytest.mark.skipif(
sys.platform == "win32", reason="Windows doesn't support any timeout logic"
)
def test_timeout_handler_actually_stops_execution():
with tempfile.TemporaryDirectory() as call_dir:
FILE = os.path.join(call_dir, "test.txt")

def slow_fn():
"Runs for 1.5 seconds, writes to file 6 times"
iters = 0
while iters < 6:
time.sleep(0.26)
with open(FILE, "a") as f:
f.write("called\n")
iters += 1

with pytest.raises(TimeoutError):
# allow for at most 3 writes
timeout_handler(slow_fn, timeout=1)

time.sleep(0.5)
with open(FILE, "r") as g:
contents = g.read()

assert len(contents.split("\n")) <= 4


def test_timeout_handler_passes_args_and_kwargs_and_returns():
def do_nothing(x, y=None):
return x, y
@@ -178,9 +207,7 @@ def my_thread():


def test_timeout_handler_doesnt_do_anything_if_no_timeout(monkeypatch):
monkeypatch.delattr(prefect.utilities.executors, "ThreadPoolExecutor")
with pytest.raises(NameError): # to test the test's usefulness...
timeout_handler(lambda: 4, timeout=1)
assert timeout_handler(lambda: 4, timeout=1) == 4
assert timeout_handler(lambda: 4) == 4


0 comments on commit 0e24b50

Please sign in to comment.
You can’t perform that action at this time.