Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to spawn workers inside daemon #1067

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/orion/executor/multiprocess_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@
if v.major == 3 and v.minor >= 8:
args = args[1:]

if Pool.ALLOW_DAEMON:
return Process(*args, **kwds)
if not Pool.ALLOW_DAEMON:
return PyPool.Process(*args, **kwds)

Check warning on line 87 in src/orion/executor/multiprocess_backend.py

View check run for this annotation

Codecov / codecov/patch

src/orion/executor/multiprocess_backend.py#L87

Added line #L87 was not covered by tests

return _Process(*args, **kwds)
return _Process(*args, **kwds, daemon=False)

def shutdown(self):
# NB: https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html
Expand Down Expand Up @@ -167,13 +167,18 @@
if n_workers <= 0:
n_workers = multiprocessing.cpu_count()

self.pool_config = {"n_workers": n_workers, "backend": backend}
self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)

def __setstate__(self, state):
self.pool = state["pool"]
log.warning("Nesting multiprocess executor")
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
self.pool_config = state["pool_config"]
backend = self.pool_config.get("backend")
n_workers = self.pool_config.get("n_workers", -1)
self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit unsure about this part. If the object is serialized and passed to the subprocess, the deserialization step will have the effect of creating another pool of n_workers, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may be able to pass a queue instead to avoid creating multiple pools,
but nesting the executor in general is a bit of a no-no.


def __getstate__(self):
return dict(pool=self.pool)
return {"pool_config": self.pool_config}

def __enter__(self):
return self
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/client/test_experiment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def main(*args, **kwargs):


def test_run_experiment_twice():
""""""
"""Makes sure the executor is not freed after workon"""

with create_experiment(config, base_trial) as (cfg, experiment, client):
client.workon(main, max_trials=10)
Expand Down
84 changes: 75 additions & 9 deletions tests/unittests/executor/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import multiprocessing
import multiprocessing.process as proc
import os
import time

Expand All @@ -9,6 +11,14 @@
from orion.executor.ray_backend import HAS_RAY, Ray
from orion.executor.single_backend import SingleExecutor

# PyTorch and torchvision are optional: record whether they can be imported
# so the tests that need them can be skipped instead of failing at import.
try:
    import torch
    from torchvision import datasets, transforms

    HAS_PYTORCH = True
except ImportError:
    # Only a missing/unimportable package means "not available"; a bare
    # ``except`` would also hide KeyboardInterrupt and SystemExit.
    HAS_PYTORCH = False


def multiprocess(n):
    """Create a ``PoolExecutor`` from ``n`` and the ``"multiprocess"`` backend name."""
    return PoolExecutor(n, "multiprocess")
Expand Down Expand Up @@ -265,7 +275,7 @@ def nested(executor):
return sum(f.get() for f in futures)


@pytest.mark.parametrize("backend", [xfail_dask_if_not_installed(Dask), SingleExecutor])
@pytest.mark.parametrize("backend", backends)
def test_nested_submit(backend):
with backend(5) as executor:
futures = [executor.submit(nested, executor) for i in range(5)]
Expand All @@ -276,17 +286,36 @@ def test_nested_submit(backend):
assert r.value == 35


@pytest.mark.parametrize("backend", [multiprocess, thread])
def test_nested_submit_failure(backend):
def inc(a):
    """Return ``a + 1``.

    Defined at module level so it can be handed to ``multiprocessing.Pool``
    by ``nested_pool``.
    """
    return a + 1


def nested_pool():
    """Spawn a ``multiprocessing.Pool`` from inside a worker and sum its results.

    Maps ``inc`` over ``[1, 2, 3, 4, 5, 6]`` in a five-process pool and
    returns the sum of the incremented values (27).

    Raises
    ------
    AssertionError
        If the process executing this function is flagged as a daemon.
    """
    # Use the public accessor instead of the private
    # ``multiprocessing.process._current_process._config`` mapping.
    assert not multiprocessing.current_process().daemon

    data = [1, 2, 3, 4, 5, 6]
    with multiprocessing.Pool(5) as pool:
        # ``map`` blocks until every result is available, which is what the
        # ``map_async`` + ``wait`` + ``get`` sequence did.
        incremented = pool.map(inc, data)

    return sum(incremented)


@pytest.mark.parametrize("backend", backends)
def test_nested_submit_pool(backend):
if backend is Dask:
pytest.xfail("Dask does not support nesting")

with backend(5) as executor:
futures = [executor.submit(nested_pool) for i in range(5)]

if backend == multiprocess:
exception = NotImplementedError
elif backend == thread:
exception = TypeError
results = executor.async_get(futures, timeout=2)

with pytest.raises(exception):
[executor.submit(nested, executor) for i in range(5)]
for r in results:
assert r.value == 27


@pytest.mark.parametrize("executor", executors)
Expand All @@ -310,3 +339,40 @@ def test_executors_del_does_not_raise(backend):
del executor.client

del executor


def pytorch_workon(pid):
    """Iterate a multi-worker PyTorch ``DataLoader`` from inside an executor worker.

    Parameters
    ----------
    pid:
        Identifier supplied by the caller; it is not used by the function body.

    Returns
    -------
    int
        Index of the last batch yielded by the loader (1 for 128 samples in
        batches of 64), or -1 if the loader yielded nothing.

    Raises
    ------
    AssertionError
        If the process executing this function is flagged as a daemon.
    """
    # Use the public accessor instead of the private
    # ``multiprocessing.process._current_process._config`` mapping.
    assert not multiprocessing.current_process().daemon

    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.FakeData(128, transform=transform)

    # ``num_workers=2`` makes the loader start its own worker processes.
    loader = torch.utils.data.DataLoader(dataset, num_workers=2, batch_size=64)

    # Initialise the index so an empty loader returns a value instead of
    # raising ``UnboundLocalError``.
    last_batch = -1
    for last_batch, _ in enumerate(loader):
        pass

    return last_batch


@pytest.mark.parametrize("backend", backends)
def test_pytorch_dataloader(backend):
    """Check that ``pytorch_workon`` runs to completion under each executor backend."""
    if backend is Dask:
        pytest.xfail("Dask does not support nesting")

    if not HAS_PYTORCH:
        # ``pytest.skip`` raises, so nothing after it would execute.
        pytest.skip("Pytorch is not installed skipping")

    with backend(2) as executor:
        futures = [executor.submit(pytorch_workon, i) for i in range(2)]

        results = executor.async_get(futures, timeout=2)

        # 128 fake samples in batches of 64 give batch indices 0 and 1.
        for r in results:
            assert r.value == 1
Loading