-
Notifications
You must be signed in to change notification settings - Fork 240
/
proc_allocator.py
143 lines (120 loc) · 4.42 KB
/
proc_allocator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Define the base ProcAllocator class."""
import numpy as np
from openmdao.core.constants import INT_DTYPE
class ProcAllocationError(Exception):
    """
    Processor allocation failure that also records the subsystems involved.

    Parameters
    ----------
    msg : str
        The message string.
    sub_inds : list of int or None
        Indices of subsystems in _subsystems_allprocs in parent.

    Attributes
    ----------
    msg : str
        The message string (also passed to the base Exception).
    sub_inds : list of int or None
        Indices of subsystems in _subsystems_allprocs in parent.
    """

    def __init__(self, msg, sub_inds=None):
        """
        Store the message and the subsystem indices on the exception.
        """
        super(ProcAllocationError, self).__init__(msg)
        # Kept as plain attributes so handlers can read them directly.
        self.sub_inds = sub_inds
        self.msg = msg
class ProcAllocator(object):
    """
    Base algorithm for handing out processors to a system's subsystems.

    Parameters
    ----------
    parallel : bool
        If True, split subsystem comm.

    Attributes
    ----------
    parallel : bool
        True means the comm is split across subsystems;
        False means the comm is passed to all subsystems.
    """

    def __init__(self, parallel=False):
        """
        Record whether the comm should be split.
        """
        self.parallel = parallel

    def __call__(self, proc_info, nsubs, comm):
        """
        Allocate processors, dividing them only when running in parallel.

        When ``parallel`` is set and the comm has more than one process, the
        work is handed to ``_divide_procs`` and its result is returned as is.
        Otherwise the per-subsystem limits are checked against the comm size
        and every subsystem is assigned to every process.

        Parameters
        ----------
        proc_info : list of (min_procs, max_procs, weight)
            Information used to determine MPI process allocation to subsystems.
        nsubs : int
            Number of subsystems.
        comm : MPI.Comm or <FakeComm>
            Communicator of the owning system.

        Returns
        -------
        isubs : [int, ...]
            Indices of the owned local subsystems.
        sub_comm : MPI.Comm or <FakeComm>
            Communicator to pass to the subsystems.
        sub_proc_range : (int, int)
            The range of processors that the subcomm owns, among those of comm.

        Raises
        ------
        ProcAllocationError
            In the non-divided case, if some subsystem's max_procs is below
            the comm size or some subsystem's min_procs is above it.
        """
        if self.parallel and comm.size > 1:
            # This is a parallel group
            return self._divide_procs(proc_info, comm)

        total = comm.size
        lower, upper, _ = self._split_proc_info(proc_info, comm)

        # Every subsystem receives the full comm here, so a subsystem whose
        # cap is below the comm size has been handed too many processes.
        capped = upper < total
        if np.any(capped):
            raise ProcAllocationError("too many MPI procs allocated (%d)" % total,
                                      np.arange(nsubs)[capped])

        starved = lower > total
        if np.any(starved):
            raise ProcAllocationError("can't meet min_procs required",
                                      np.arange(nsubs)[starved])

        # This is a serial group - all procs get all subsystems
        all_subs = list(range(nsubs))
        return all_subs, comm, (0, comm.size)

    def _split_proc_info(self, proc_info, comm):
        """
        Separate proc_info into arrays of min_procs, max_procs and weights.

        Parameters
        ----------
        proc_info : list of (min_procs, max_procs, weight)
            Information used to determine MPI process allocation to subsystems.
        comm : MPI.Comm or <FakeComm>
            Communicator of the owning system.

        Returns
        -------
        ndarray of int
            Min procs required for each subsystem.
        ndarray of int
            Max procs for each subsystem, clipped to the comm size.
        ndarray
            Weights for each subsystem.
        """
        nproc = comm.size

        mins = []
        maxs = []
        wts = []
        for minp, maxp, weight in proc_info:
            mins.append(minp)
            # if max_procs entry is None or > nproc, it just becomes nproc
            if maxp is None or maxp > nproc:
                maxs.append(nproc)
            else:
                maxs.append(maxp)
            wts.append(weight)

        min_procs = np.array(mins, dtype=INT_DTYPE)
        max_procs = np.array(maxs, dtype=INT_DTYPE)
        weights = np.array(wts)
        return min_procs, max_procs, weights

    def _divide_procs(self, proc_info, comm):
        """
        Perform the parallel processor allocation.

        This base implementation does nothing and implicitly returns None.
        The Returns section below lists the values ``__call__`` documents for
        its own result, which it takes directly from this method.

        Parameters
        ----------
        proc_info : list of (min_procs, max_procs, weight)
            Information used to determine MPI process allocation to subsystems.
        comm : MPI.Comm or <FakeComm>
            Communicator of the owning system.

        Returns
        -------
        isubs : [int, ...]
            Indices of the owned local subsystems.
        sub_comm : MPI.Comm or <FakeComm>
            Communicator to pass to the subsystems.
        sub_proc_range : (int, int)
            The range of processors that the subcomm owns, among those of comm.
        """
        # NOTE(review): presumably overridden by concrete allocators - confirm.
        pass