Add MPI-capable tasks #9

Merged
61 commits merged on May 29, 2019
Changes from 16 commits
Commits
61 commits
120be38
Add proof of concept for tasks with MPI capabilities (in the same nam…
May 20, 2019
e9ca2bc
Tidy up a bit
May 20, 2019
e22b42b
Careful about (ab)using globals
May 20, 2019
091fd1e
Make script executable
May 20, 2019
6c2b8b7
Fix spelling error
May 20, 2019
8fa6bce
Merge branch 'mpi_enabled_tasks' of github.com:E-CAM/jobqueue_feature…
May 20, 2019
78bdffd
Merge branch 'master' into mpi_enabled_tasks
May 20, 2019
978176a
Start incorporating the necessary code
May 20, 2019
b00e9c0
Only import MPI if I need to
May 20, 2019
5c2715b
Add some error checking
May 21, 2019
356ea7e
Move some more functions to the module, improve errors
May 21, 2019
e147ba4
Add the wrapper script for dask_worker
May 21, 2019
af3a043
Simplify
May 21, 2019
7939481
Prepare for wrapping mpi_dask_worker.py call
May 21, 2019
7046a87
Add dask mpi launcher
May 21, 2019
4745140
Run black
May 21, 2019
11fcb70
Provide traceback when aborting
May 21, 2019
4b1c8fa
Enable switch for forking MPI programs
May 22, 2019
6317d56
Run black on codebase
May 22, 2019
d0eebce
Add default MPI launcher to tests
May 22, 2019
7ac93a4
Restore state in tests
May 22, 2019
628e94b
Prepare for more tests
May 22, 2019
6a14b26
Correct warning string
May 22, 2019
aec9be0
Correct warning string
May 23, 2019
f7b4fa9
Correct warning string
May 23, 2019
145556c
Imports done incorrectly
May 23, 2019
4b5d6d3
Add an example to test MPI functionality
May 23, 2019
d6f81a8
Add launcher to example
May 23, 2019
0e2e83d
Use trivial queue
May 23, 2019
e372d0c
Fix typo
May 23, 2019
0b5ea90
Only add additional kwargs if they are needed
May 23, 2019
7597b80
No forking other processes when using MPI to launch dask
May 23, 2019
09d5cf8
Add missing kwargs
May 23, 2019
1b9d8b8
Use simple queues in example
May 23, 2019
4552e29
Move requirements in CI
May 23, 2019
e2f0aaf
Only do the dask worker import where it is needed
May 24, 2019
8c830cd
Merge remote-tracking branch 'origin/mpi_enabled_tasks' into mpi_enab…
May 24, 2019
1ce44a6
Debugging
May 24, 2019
14f5e5d
Debugging
May 24, 2019
6a7168a
Serialize before submitting the task
May 24, 2019
0bf2dd1
Serialize before submitting the task
May 24, 2019
73450af
Serialize before submitting the task
May 24, 2019
af29b1f
Be more careful with args and kwargs
May 24, 2019
85be2ba
Get rid of a warning
May 24, 2019
040fd7a
Remove unnecessary print statements
May 24, 2019
55b3da0
Make sure we can run multiple tasks
May 24, 2019
60048dc
Make sure we can run multiple tasks
May 24, 2019
f3543a7
Make sure we can run multiple tasks
May 24, 2019
84a5d25
Flush strings
May 24, 2019
32a7e86
Remove bad print command
May 24, 2019
7bf1799
See if nanny works
May 24, 2019
a6c171e
Only do import when we need to
May 24, 2019
1788b86
Only do import when we need to
May 24, 2019
707a4f0
Revert last commit
May 24, 2019
db6f089
Address comment
May 29, 2019
584c33e
Tidy up before adding additional tests
May 29, 2019
4f0aace
Remove dask worker directories
May 29, 2019
e59b5a7
Add initial set of new tests
May 29, 2019
d48d54e
Additional tests
May 29, 2019
b586dc9
Final additional tests
May 29, 2019
6d3c6ea
Make sure Py2 is also passing
May 29, 2019
34 changes: 34 additions & 0 deletions jobqueue_features/mpi_dask_worker.py
@@ -0,0 +1,34 @@
#!/usr/bin/env python
"""
Distribution of MPI enabled tasks
"""

from jobqueue_features.mpi_wrapper import (
    mpi_deserialize_and_execute,
    serialize_function_and_args,
    shutdown_mpitask_worker,
)
from distributed.cli.dask_worker import go
from mpi4py import MPI


def prepare_for_mpi_tasks():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    if rank == 0:
        # Start dask so root reports to scheduler and accepts tasks
        # Task distribution is part of task itself (via our wrapper)
        go()

        # As a final task, send a shutdown to the other MPI ranks
        serialized_object = serialize_function_and_args(shutdown_mpitask_worker)
        mpi_deserialize_and_execute(serialized_object=serialized_object)
    else:
        while True:
            # Calling with no arguments means these are non-root processes
            mpi_deserialize_and_execute()


if __name__ == "__main__":
    prepare_for_mpi_tasks()
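For orientation (not part of the diff itself): this worker script is meant to be started under an MPI launcher, e.g. something along the lines of mpirun -np 4 python mpi_dask_worker.py <scheduler-address> (the launcher, rank count and exact arguments here are illustrative, not fixed by this PR). Rank 0 then behaves as an ordinary dask worker while the remaining ranks wait in the mpi_deserialize_and_execute() loop. A minimal sketch of a task that uses this protocol, with hypothetical names (demo_task, run_on_all_ranks), could look like:

# Hypothetical sketch -- demo_task and run_on_all_ranks are illustrative names,
# not part of this PR. A function submitted to the rank-0 dask worker can itself
# broadcast the real work to the other MPI ranks via the helpers added here.
from jobqueue_features.mpi_wrapper import (
    mpi_deserialize_and_execute,
    serialize_function_and_args,
)


def demo_task(label):
    from mpi4py import MPI

    # Runs on every rank once the broadcast below has happened
    return "%s ran on %d ranks" % (label, MPI.COMM_WORLD.Get_size())


def run_on_all_ranks():
    # Executed on rank 0, inside the dask worker started by prepare_for_mpi_tasks()
    serialized = serialize_function_and_args(demo_task, "demo")
    # Broadcast to the waiting ranks and execute everywhere; root returns the result
    return mpi_deserialize_and_execute(serialized_object=serialized)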
133 changes: 123 additions & 10 deletions jobqueue_features/mpi_wrapper.py
@@ -1,5 +1,7 @@
from distributed.protocol import serialize, deserialize
import shlex
import subprocess
import sys
from typing import Dict # noqa


@@ -19,6 +21,7 @@ def mpi_wrap(
    cpus_per_task=None,
    ntasks_per_node=None,
    exec_args="",
    return_wrapped_command=False,
    **kwargs
):
    # type: (str, str, str, str, str, ...) -> Dict[str, str]
@@ -79,16 +82,126 @@ def get_default_mpi_params(
            exec_args=exec_args,
        )
    )
    if return_wrapped_command:
        result = cmd
    else:
        try:
            proc = subprocess.Popen(
                shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            out = proc.stdout.read()
            err = proc.stderr.read()
        except OSError as err:
            raise OSError(
                "OS error caused by constructed command: {cmd}\n\n{err}".format(
                    cmd=cmd, err=err
                )
            )
        result = {"cmd": cmd, "out": out, "err": err}

    return result
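A brief, hedged illustration of what the new return_wrapped_command flag changes: when it is set, the constructed launcher command is returned as a string instead of being executed through subprocess. The executable and mpi_launcher arguments below are assumptions about the part of the mpi_wrap() signature collapsed above, used purely for illustration:

# Sketch only -- 'executable' and 'mpi_launcher' are guesses at the collapsed part
# of the mpi_wrap() signature above; return_wrapped_command is the flag added here.
from jobqueue_features.mpi_wrapper import mpi_wrap

# Build and return the wrapped command string without running it
wrapped_cmd = mpi_wrap(
    executable="./my_mpi_program",
    mpi_launcher="mpirun",
    return_wrapped_command=True,
)

# Default behaviour: run the command and get back {"cmd": ..., "out": ..., "err": ...}
result = mpi_wrap(executable="./my_mpi_program", mpi_launcher="mpirun")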


def shutdown_mpitask_worker():
    from mpi4py import MPI

    # Finalise MPI
    MPI.Finalize()
    # and then exit
    exit()


def deserialize_and_execute(serialized_object):
    # Ensure the serialized object is of the expected type
    if isinstance(serialized_object, dict):
        # Make sure it has the expected entries
        if not ("header" in serialized_object and "frames" in serialized_object):
            raise RuntimeError(
                "serialized_object dict does not have expected keys [header, frames]"
            )
    else:
        raise RuntimeError("Cannot deserialize without a serialized_object")
    func = deserialize(serialized_object["header"], serialized_object["frames"])
    if serialized_object.get("args_header"):
        args = deserialize(
            serialized_object["args_header"], serialized_object["args_frames"]
        )
    else:
        args = []
    if serialized_object.get("kwargs_header"):
        kwargs = deserialize(
            serialized_object["kwargs_header"], serialized_object["kwargs_frames"]
        )
    else:
        kwargs = {}

    # Free memory space used by (potentially large) serialised object
    del serialized_object

    # Execute the function and return
    return func(*args, **kwargs)


def flush_and_abort(msg="Flushing print buffer and aborting", comm=None, error_code=1):
    from mpi4py import MPI

    if comm is None:
        comm = MPI.COMM_WORLD
    print(msg)
    sys.stdout.flush()
    if error_code == 0:
        print("To abort correctly, we need to use a non-zero error code")
        error_code = 1
    comm.Abort(error_code)


def mpi_deserialize_and_execute(serialized_object=None, root=0, comm=None):
    from mpi4py import MPI

    if comm is None:
        comm = MPI.COMM_WORLD

    # We only handle the case where root has the object and is the one who returns
    # something
    if serialized_object:
        # Check we have a valid communicator
        try:
            rank = comm.Get_rank()
        except AttributeError:
            flush_and_abort(
                msg="Looks like you did not pass a valid MPI communicator, aborting "
                "using global communicator"
            )
            sys.stdout.flush()
            MPI.COMM_WORLD.Abort(1)
        if rank != root:
            flush_and_abort(
                msg="Only root rank (%d) can contain a serialized object for this "
                "call, my rank is %d...aborting!" % (root, rank),
                comm=comm,
            )
        return_something = True
    else:
        return_something = False
    serialized_object = comm.bcast(serialized_object, root=root)
    result = deserialize_and_execute(serialized_object)

    if return_something and result:
        return result


def serialize_function_and_args(func, *args, **kwargs):
    header, frames = serialize(func)
    serialized_object = {"header": header, "frames": frames}
    if args:
        args_header, args_frames = serialize(args)
        serialized_object.update(
            {"args_header": args_header, "args_frames": args_frames}
        )
    if kwargs:
        kwargs_header, kwargs_frames = serialize(kwargs)
        serialized_object.update(
            {"kwargs_header": kwargs_header, "kwargs_frames": kwargs_frames}
        )

    return serialized_object
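The two serialization helpers round-trip through distributed's serialize()/deserialize(); a minimal single-process sketch (no MPI involved, the add() function and its arguments are illustrative):

# Single-process sketch of the round trip; add() and its arguments are illustrative.
from jobqueue_features.mpi_wrapper import (
    deserialize_and_execute,
    serialize_function_and_args,
)


def add(a, b=0):
    return a + b


payload = serialize_function_and_args(add, 1, b=2)
# payload contains "header"/"frames" for the function, plus "args_header"/"args_frames"
# and "kwargs_header"/"kwargs_frames" because args and kwargs were supplied
print(deserialize_and_execute(payload))  # prints 3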
53 changes: 53 additions & 0 deletions mpi_pickling.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""
Distribution of MPI enabled tasks
"""

from jobqueue_features.mpi_wrapper import (
    mpi_deserialize_and_execute,
    serialize_function_and_args,
    shutdown_mpitask_worker,
)
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # This is the task, which is only defined on root
    def task1(task_name):
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        size = comm.Get_size()
        name = MPI.Get_processor_name()
        all_nodes = comm.gather(name, root=0)
        if all_nodes:
            all_nodes = set(all_nodes)
        else:
            all_nodes = []
        # Since it is a return value it will only get printed by root
        return "Running %d tasks of type %s on nodes %s." % (size, task_name, all_nodes)

    def task2(name, task_name="default"):
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        print("Hi %s, my rank is %d for task of type %s" % (name, rank, task_name))

    serialized_object = serialize_function_and_args(task1, "task1")
    result = mpi_deserialize_and_execute(serialized_object=serialized_object)
    if result:
        print(result)
    serialized_object = serialize_function_and_args(task2, "alan", task_name="task2")
    mpi_deserialize_and_execute(serialized_object=serialized_object)

    # As a final task, send a shutdown to the other MPI ranks
    serialized_object = serialize_function_and_args(shutdown_mpitask_worker)
    mpi_deserialize_and_execute(serialized_object=serialized_object)

else:
    while True:
        # Calling with no arguments means these are non-root processes
        mpi_deserialize_and_execute()
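Note that this example only makes sense when launched with more than one MPI rank, e.g. something like mpirun -np 4 python mpi_pickling.py (the launcher name and rank count are illustrative); with a single rank the non-root branch that waits for broadcast work is never exercised.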