Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure subprocess/process/thread/dask can create a runner instance #816

Merged
merged 16 commits into from
Mar 3, 2022
18 changes: 16 additions & 2 deletions src/orion/client/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,23 @@ def __init__(self):
self.handlers = dict()
self.start = 0
self.delayed = 0
self.signal_installed = False

def __enter__(self):
"""Override the signal handlers with our delayed handler"""
self.signal_received = False
self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler)
self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler)

try:
self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler)
self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler)
self.signal_installed = True
except ValueError:
log.warning(
"SIGINT/SIGTERM protection hooks could not be installed because "
"Runner is executing inside a thread/subprocess, results could get lost "
"on interruptions"
)

return self

def handler(self, sig, frame):
Expand All @@ -65,6 +76,9 @@ def handler(self, sig, frame):

def restore_handlers(self):
"""Restore old signal handlers"""
if not self.signal_installed:
return

signal.signal(signal.SIGINT, self.handlers[signal.SIGINT])
signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM])

Expand Down
100 changes: 98 additions & 2 deletions tests/unittests/client/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import signal
import time
from contextlib import contextmanager
from multiprocessing import Process
from multiprocessing import Process, Queue
from threading import Thread

import pytest
Expand All @@ -21,8 +21,8 @@
)
from orion.core.worker.trial import Trial
from orion.executor.base import executor_factory
from orion.executor.dask_backend import HAS_DASK, Dask
from orion.storage.base import LockAcquisitionTimeout
from orion.testing import create_experiment


def new_trial(value, sleep=0.01):
Expand Down Expand Up @@ -535,3 +535,99 @@ def make_runner(n_workers, max_trials_per_worker, pool_size=None):
runner.trials = 5
assert runner.should_sample() == 0, "The max number of trials was reached"
runner.client.close()


def run_runner():
try:
count = 10
max_trials = 10
workers = 2

runner = new_runner(0.1, n_workers=workers)
runner.max_trials_per_worker = max_trials
client = runner.client

client.trials.extend([new_trial(i) for i in range(count)])

runner.run()
runner.client.close()
print("done")
return 0
except:
return 1


def test_runner_inside_process():
"""Runner can execute inside a process"""

queue = Queue()

def get_result(results):
results.put(run_runner())

p = Process(target=get_result, args=(queue,))
p.start()
p.join()

assert queue.get() == 0
assert p.exitcode == 0


def test_runner_inside_childprocess():
"""Runner can execute inside a child process"""
pid = os.fork()

# execute runner in the child process
if pid == 0:
run_runner()
os._exit(0)
else:
# parent process wait for child process to end
wpid, exit_status = os.wait()
assert wpid == pid
assert exit_status == 0


def test_runner_inside_subprocess():
"""Runner can execute inside a subprocess"""

import subprocess

dir = os.path.dirname(__file__)

result = subprocess.run(
["python", f"{dir}/runner_subprocess.py"], capture_output=True
)

assert result.stdout.decode("utf-8") == "done\n"
assert result.stderr.decode("utf-8") == ""
assert result.returncode == 0


def test_runner_inside_thread():
"""Runner can execute inside a thread"""

class GetResult:
def __init__(self) -> None:
self.r = None

def run(self):
self.r = run_runner()

result = GetResult()
thread = Thread(target=result.run)
thread.start()
thread.join()

assert result.r == 0


@pytest.mark.skipif(not HAS_DASK, reason="Running without dask")
def test_runner_inside_dask():
"""Runner can execute inside a dask"""

client = Dask()

future = client.submit(run_runner)

assert future.get() == 0