Skip to content

Commit

Permalink
Make sure subprocess/process/thread/dask can create a runner instance (
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay committed Mar 3, 2022
1 parent f0c3af5 commit 6a3f9be
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 15 deletions.
19 changes: 17 additions & 2 deletions src/orion/client/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,24 @@ def __init__(self):
self.handlers = dict()
self.start = 0
self.delayed = 0
self.signal_installed = False

def __enter__(self):
    """Install the delayed SIGINT/SIGTERM handler, remembering the old ones.

    ``signal.signal`` raises ``ValueError`` outside the main thread; in that
    case the hooks are skipped (``signal_installed`` stays False) and a
    warning is logged, so the Runner still works in threads/subprocesses.
    """
    self.signal_received = False

    try:
        # Save the previous handlers so restore_handlers() can put them back.
        self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler)
        self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler)
        self.signal_installed = True

    except ValueError:  # ValueError: signal only works in main thread
        log.warning(
            "SIGINT/SIGTERM protection hooks could not be installed because "
            "Runner is executing inside a thread/subprocess, results could get lost "
            "on interruptions"
        )

    return self

def handler(self, sig, frame):
Expand All @@ -65,6 +77,9 @@ def handler(self, sig, frame):

def restore_handlers(self):
    """Reinstate the signal handlers that were active before ``__enter__``.

    A no-op when the hooks were never installed (non-main thread).
    """
    if not self.signal_installed:
        return

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, self.handlers[signum])

Expand Down
10 changes: 10 additions & 0 deletions src/orion/executor/dask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def successful(self):


class Dask(BaseExecutor):
"""Wrapper around the dask client.
.. warning::
The Dask executor can be pickled and used inside a subprocess,
the pickled client will use the main client that was spawned in the main process,
but you cannot spawn clients inside a subprocess.
"""

def __init__(self, n_workers=-1, client=None, **config):
super(Dask, self).__init__(n_workers=n_workers)

Expand Down
17 changes: 10 additions & 7 deletions src/orion/executor/multiprocess_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class PoolExecutor(BaseExecutor):
backend: str
Pool backend to use; thread or multiprocess, defaults to multiprocess
.. warning::
Pickling of the executor is not supported, see Dask for a backend that supports it
"""

BACKENDS = dict(
Expand All @@ -173,6 +177,12 @@ def __init__(self, n_workers=-1, backend="multiprocess", **kwargs):

self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)

def __setstate__(self, state):
    # Unpickling: restore only the wrapped pool handle.
    # NOTE(review): the base-class __setstate__ is no longer called here —
    # confirm BaseExecutor carries no state of its own that would be lost.
    self.pool = state["pool"]

def __getstate__(self):
    # Pickle only the pool handle; the class-level warning says pickling of
    # this executor is not supported, so this mainly keeps the error explicit
    # at unpickle/use time rather than at pickle time.
    return dict(pool=self.pool)

def __enter__(self):
    # Context-manager entry: no extra setup beyond construction.
    return self

Expand All @@ -188,13 +198,6 @@ def close(self):
if hasattr(self, "pool"):
self.pool.shutdown()

def __getstate__(self):
state = super(PoolExecutor, self).__getstate__()
return state

def __setstate__(self, state):
super(PoolExecutor, self).__setstate__(state)

def submit(self, function, *args, **kwargs):
try:
return self._submit_cloudpickle(function, *args, **kwargs)
Expand Down
91 changes: 91 additions & 0 deletions tests/unittests/client/runner_subprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Used to test instantiating a runner inside a subprocess"""
from argparse import ArgumentParser

from orion.client.runner import Runner
from orion.core.utils.exceptions import WaitingForTrials
from orion.core.worker.trial import Trial
from orion.executor.base import executor_factory

idle_timeout = 20
count = 10
n_workers = 2


parser = ArgumentParser()
parser.add_argument("--backend", type=str, default="joblib")
args = parser.parse_args()


def new_trial(value, sleep=0.01):
    """Generate a dummy new trial"""
    trial_params = [
        dict(name="lhs", type="real", value=value),
        dict(name="sleep", type="real", value=sleep),
    ]
    return Trial(params=trial_params)


class FakeClient:
    """Minimal stand-in for the Orion client, just enough for Runner."""

    def __init__(self, n_workers):
        self.is_done = False
        self.executor = executor_factory.create(args.backend, n_workers)
        self.suggest_error = WaitingForTrials
        self.trials = []
        self.status = []
        self.working_dir = ""

    def suggest(self, pool_size=None):
        """Pop the next queued trial; raise the configured error when empty."""
        if not self.trials:
            raise self.suggest_error
        return self.trials.pop()

    def release(self, trial, status=None):
        """Record the status the trial was released with."""
        self.status.append(status)

    def observe(self, trial, value):
        """Mark a trial as completed."""
        self.status.append("completed")

    def close(self):
        self._free_executor()

    def __del__(self):
        self._free_executor()

    def _free_executor(self):
        # Shut the executor down exactly once; subsequent calls are no-ops.
        if self.executor is not None:
            self.executor.__exit__(None, None, None)
            self.executor = None
            self.executor_owner = False


def function(lhs, sleep):
    """Dummy objective: return the sum of its two inputs."""
    total = lhs + sleep
    return total


# Exercise the Runner end-to-end inside this (sub)process: queue ``count``
# dummy trials on a mock client and run them with the chosen backend.
client = FakeClient(n_workers)

runner = Runner(
    client=client,
    fct=function,
    pool_size=10,
    idle_timeout=idle_timeout,
    max_broken=2,
    max_trials_per_worker=2,
    trial_arg=[],
    on_error=None,
)

client = runner.client

client.trials.extend([new_trial(i) for i in range(count)])

runner.run()
runner.client.close()
# The parent test process checks stdout for this marker to confirm success.
print("done")
140 changes: 134 additions & 6 deletions tests/unittests/client/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import copy
import os
import signal
import sys
import time
import traceback
from contextlib import contextmanager
from multiprocessing import Process
from multiprocessing import Process, Queue
from threading import Thread
from wsgiref.simple_server import sys_version

import pytest

Expand All @@ -21,8 +24,12 @@
)
from orion.core.worker.trial import Trial
from orion.executor.base import executor_factory
from orion.executor.dask_backend import HAS_DASK, Dask
from orion.storage.base import LockAcquisitionTimeout
from orion.testing import create_experiment


def compatible(version):
    """Return True when the interpreter is the given major and at least the given minor."""
    info = sys.version_info
    return info.major == version[0] and info.minor >= version[1]


def new_trial(value, sleep=0.01):
Expand All @@ -47,9 +54,14 @@ def change_signal_handler(sig, handler):
class FakeClient:
"""Orion mock client for Runner."""

def __init__(self, n_workers):
def __init__(self, n_workers, backend="joblib", executor=None):
self.is_done = False
self.executor = executor_factory.create("joblib", n_workers)

if executor is None:
self.executor = executor_factory.create(backend, n_workers)
else:
self.executor = executor

self.suggest_error = WaitingForTrials
self.trials = []
self.status = []
Expand Down Expand Up @@ -100,10 +112,10 @@ def function(lhs, sleep):
return lhs + sleep


def new_runner(idle_timeout, n_workers=2, client=None):
def new_runner(idle_timeout, n_workers=2, client=None, executor=None, backend="joblib"):
"""Create a new runner with a mock client."""
if client is None:
client = FakeClient(n_workers)
client = FakeClient(n_workers, backend=backend, executor=executor)

runner = Runner(
client=client,
Expand Down Expand Up @@ -535,3 +547,119 @@ def make_runner(n_workers, max_trials_per_worker, pool_size=None):
runner.trials = 5
assert runner.should_sample() == 0, "The max number of trials was reached"
runner.client.close()


def run_runner(reraise=False, executor=None):
    """Drive a Runner to completion with mocked trials.

    Returns 0 on success and 1 on failure (with the traceback printed),
    unless ``reraise`` is True, in which case the exception propagates.
    """
    try:
        count = 10
        max_trials = 10
        workers = 2

        runner = new_runner(0.1, n_workers=workers, executor=executor)
        runner.max_trials_per_worker = max_trials
        client = runner.client

        client.trials.extend([new_trial(i, sleep=0) for i in range(count)])

        if executor is None:
            executor = client.executor

        def set_is_done():
            # Give run() a moment to start, then force completion so it returns.
            time.sleep(0.05)
            runner.pending_trials = dict()
            runner.client.is_done = True

        thread = Thread(target=set_is_done)
        thread.start()

        with executor:
            runner.run()

        # Don't leak the helper thread past this call.
        thread.join()

        print("done")
        return 0
    except BaseException:
        # Catch everything (including KeyboardInterrupt) so a failure inside
        # a worker thread/process is reported as a return code, not a crash.
        if reraise:
            raise

        traceback.print_exc()
        return 1


def test_runner_inside_process():
    """Runner can execute inside a process"""

    results = Queue()

    def worker(out):
        out.put(run_runner())

    proc = Process(target=worker, args=(results,))
    proc.start()
    proc.join()

    assert results.get() == 0
    assert proc.exitcode == 0


def test_runner_inside_childprocess():
    """Runner can execute inside a child process (POSIX fork)."""
    pid = os.fork()

    # execute runner in the child process
    if pid == 0:
        # Propagate the runner's status code instead of always exiting 0 —
        # otherwise a failing run_runner() could never fail this test.
        os._exit(run_runner())
    else:
        # parent process waits for the child and decodes the raw wait status
        wpid, wait_status = os.wait()
        assert wpid == pid
        assert os.WIFEXITED(wait_status)
        assert os.WEXITSTATUS(wait_status) == 0


def test_runner_inside_subprocess():
    """Runner can execute inside a subprocess"""

    import subprocess

    dir = os.path.dirname(__file__)

    # Use sys.executable so the child runs under the same interpreter (and
    # virtualenv) as the test session; a bare "python" may resolve to a
    # different installation or not exist at all.
    result = subprocess.run(
        [sys.executable, f"{dir}/runner_subprocess.py", "--backend", "joblib"],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    assert result.stderr.decode("utf-8") == ""
    assert result.stdout.decode("utf-8") == "done\n"
    assert result.returncode == 0


def test_runner_inside_thread():
    """Runner can execute inside a thread"""

    outcome = []

    def capture():
        outcome.append(run_runner())

    worker = Thread(target=capture)
    worker.start()
    worker.join()

    assert outcome == [0]


@pytest.mark.skipif(not HAS_DASK, reason="Running without dask")
def test_runner_inside_dask():
    """Runner can execute inside a dask worker (the asserted result is 0, i.e. success)."""

    executor = Dask()

    # reraise=True so a failure inside the worker surfaces here instead of
    # being converted into a return code of 1.
    future = executor.submit(run_runner, executor=executor, reraise=True)

    assert future.get() == 0
Loading

0 comments on commit 6a3f9be

Please sign in to comment.