/
mpi_pool.py
61 lines (46 loc) · 1.48 KB
/
mpi_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Engines with multi-node parallelization."""
import logging
from typing import Any
import cloudpickle as pickle
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from ..util import tqdm
from .base import Engine
from .task import Task
# Module-level logger named after this module, so its output can be
# filtered or configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
def work(pickled_task):
    """
    Deserialize a single task and run it on the current process.

    Parameters
    ----------
    pickled_task:
        Serialized task bytes, as produced by ``pickle.dumps``.

    Returns
    -------
    Whatever the deserialized task's ``execute()`` method returns.
    """
    # Within this module the payload is produced by
    # MPIPoolEngine.execute on the manager process, so unpickling it here is
    # expected; do not feed this function bytes from an untrusted source.
    return pickle.loads(pickled_task).execute()
class MPIPoolEngine(Engine):
    """
    Parallelize the task execution.

    Uses `mpi4py <https://mpi4py.readthedocs.io/en/stable/>`_.
    To be called with:

    ``mpiexec -np #Workers+1 python -m mpi4py.futures YOURFILE.py``

    One MPI process acts as the manager that distributes the tasks; the
    remaining processes act as workers.
    """

    def __init__(self):
        super().__init__()

    def execute(
        self, tasks: list[Task], progress_bar: bool = None
    ) -> list[Any]:
        """
        Pickle tasks and distribute work to workers.

        Parameters
        ----------
        tasks:
            List of :class:`pypesto.engine.Task` to execute.
        progress_bar:
            Whether to display a progress bar. The bar advances as results
            are collected from the workers.

        Returns
        -------
        A list of results, in the same order as ``tasks``.
        """
        # ``pickle`` is cloudpickle here, which can also serialize objects
        # such as lambdas and closures that the standard pickle rejects.
        pickled_tasks = [pickle.dumps(task) for task in tasks]

        # COMM_WORLD also counts the manager process, hence the ``- 1``.
        n_procs = MPI.COMM_WORLD.Get_size()
        logger.info(
            "Parallelizing on %d workers with one manager.", n_procs - 1
        )

        with MPIPoolExecutor() as executor:
            # ``executor.map`` submits every task up front and returns a lazy
            # iterator. Wrapping its *output* makes the progress bar track
            # completed tasks; wrapping the input only tracked submission,
            # which finishes immediately.
            # The iterator is materialized inside the ``with`` block so
            # callers receive the ``list`` promised by the signature rather
            # than a one-shot iterator.
            results = list(
                tqdm(
                    executor.map(work, pickled_tasks),
                    total=len(pickled_tasks),
                    enable=progress_bar,
                )
            )
        return results