-
Notifications
You must be signed in to change notification settings - Fork 240
/
parallel_group.py
164 lines (140 loc) · 5.31 KB
/
parallel_group.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Define the ParallelGroup class."""
from openmdao.core.group import Group
from openmdao.utils.om_warnings import issue_warning
class ParallelGroup(Group):
    """
    Class used to group systems together to be executed in parallel.

    When this group's communicator has more than one process, several methods below call
    ``self.comm.allgather``.  Assuming an MPI-style (collective) communicator, every rank of
    ``self.comm`` must invoke those methods together.

    Parameters
    ----------
    **kwargs : dict
        Dict of arguments available here and in all descendants of this Group.
    """

    def __init__(self, **kwargs):
        """
        Initialize the group and flag its MPI proc allocator as parallel.
        """
        super().__init__(**kwargs)
        self._mpi_proc_allocator.parallel = True

    def _configure(self):
        """
        Configure our model recursively to assign any children settings.

        Highest system's settings take precedence.  When running on more than one process,
        ``_has_guess`` becomes True on every rank if it is True on any rank.
        """
        super()._configure()
        if self.comm.size > 1:
            self._has_guess = any(self.comm.allgather(self._has_guess))

    def _get_sys_promotion_tree(self, tree):
        """
        Return the system promotion tree, merged across all ranks of this group.

        Parameters
        ----------
        tree : dict
            Promotion tree whose keys are system pathname strings (they are matched
            against this group's pathname prefix).

        Returns
        -------
        dict
            The tree returned by the parent implementation.  When running on more than one
            process it is extended in place with entries under this group that were only
            present on other ranks.
        """
        tree = super()._get_sys_promotion_tree(tree)
        if self.comm.size > 1:
            prefix = self.pathname + '.' if self.pathname else ''
            # Only entries for descendants of this group are shared between ranks.
            subtree = {n: data for n, data in tree.items() if n.startswith(prefix)}
            for sub in self.comm.allgather(subtree):  # TODO: make this more efficient
                for n, data in sub.items():
                    # Local entries win; only fill in systems this rank doesn't know about.
                    if n not in tree:
                        tree[n] = data
        return tree

    def _ordered_comp_name_iter(self):
        """
        Yield contained component pathnames in order of execution.

        For components within ParallelGroups, true execution order is unknown so components
        will be ordered by rank within a ParallelGroup.

        Yields
        ------
        str
            Pathname of a component contained in this group.  Each pathname is yielded once.
        """
        if self.comm.size > 1:
            # Names known on this rank: subgroups contribute their own ordered component
            # names, non-group subsystems contribute their own pathname.
            names = []
            for s in self._subsystems_myproc:
                if isinstance(s, Group):
                    names.extend(s._ordered_comp_name_iter())
                else:
                    names.append(s.pathname)

            # Concatenate in gather (rank) order, skipping pathnames already yielded,
            # since the same pathname can be reported by more than one rank.
            seen = set()
            for ranknames in self.comm.allgather(names):
                for name in ranknames:
                    if name not in seen:
                        yield name
                        seen.add(name)
        else:
            yield from super()._ordered_comp_name_iter()

    def _check_order(self, reorder=True, recurse=True, out_of_order=None):
        """
        Check if auto ordering is needed and if so, set the order appropriately.

        The subsystems of a ParallelGroup itself are never reordered.  If ``auto_order`` is
        set, a warning is issued and the option is ignored.  Subgroups on this process are
        still checked when ``recurse`` is True.

        Parameters
        ----------
        reorder : bool
            If True, reorder the subsystems based on the new order. Otherwise
            just return the out-of-order connections.
        recurse : bool
            If True, call this method on all subgroups.
        out_of_order : dict
            Lists of out-of-order connections keyed by group pathname.

        Returns
        -------
        dict
            Lists of out-of-order connections keyed by group pathname.
        """
        if self.options['auto_order']:
            # True execution order within a ParallelGroup is unknown, so auto ordering
            # is not applied here.
            issue_warning("auto_order is not supported in ParallelGroup. "
                          "Ignoring auto_order option.", prefix=self.msginfo)

        if out_of_order is None:
            out_of_order = {}

        if recurse:
            for s in self._subgroups_myproc:
                s._check_order(reorder, recurse, out_of_order)

        return out_of_order

    def comm_info_iter(self):
        """
        Yield comm info for this system and all subsystems.

        When this group runs on more than one process, the entries produced by the parent
        implementation on every rank are gathered and yielded in gather (rank) order.

        Yields
        ------
        tuple
            One entry from the parent class's ``comm_info_iter``, described as a tuple of
            the form (abs_name, comm_size).  NOTE(review): an earlier summary of this method
            also mentioned rank; confirm the exact tuple layout against the base class.
        """
        if self.comm.size > 1:
            # NOTE(review): unlike the other gathering iterators in this class, entries are
            # not de-duplicated, so an entry reported by several ranks (e.g. this group's
            # own entry) is yielded once per rank -- confirm this is intended.
            for info in self.comm.allgather(list(super().comm_info_iter())):
                yield from info
        else:
            yield from super().comm_info_iter()

    def _declared_partials_iter(self):
        """
        Iterate over all declared partials.

        When running on more than one process, ``self._subjacs_info`` is gathered from every
        rank that reports full data (per ``_gather_full_data()``).  Each key is yielded once,
        with the metadata from the first rank (in gather order) that reported it.

        Yields
        ------
        (key, meta) : (key, dict)
            key: a tuple of the form (of, wrt)
            meta: a dict containing the partial metadata
        """
        if self.comm.size > 1:
            # Ranks that should not contribute send an empty dict but still take part
            # in the allgather.
            local = self._subjacs_info if self._gather_full_data() else {}
            seen = set()
            for rankdict in self.comm.allgather(local):
                for key, meta in rankdict.items():
                    if key not in seen:
                        yield key, meta
                        seen.add(key)
        else:
            yield from super()._declared_partials_iter()

    def _get_missing_partials(self, missing):
        """
        Record derivatives that have not been declared, keyed by system pathname.

        This method returns None; results are written into ``missing`` in place.  When
        running on more than one process, the results found on each rank that reports full
        data (per ``_gather_full_data()``) are gathered.  For each system pathname, the
        entry from the first rank (in gather order) that reported it is stored.

        Parameters
        ----------
        missing : dict
            Dictionary of missing derivatives keyed by system pathname, modified in place.
            Values are whatever the parent implementation records (presumably collections
            of (of, wrt) pairs -- confirm against the base class).
        """
        if self.comm.size > 1:
            # Collect this rank's results separately so they can be shared across ranks.
            msng = {}
            super()._get_missing_partials(msng)
            gathered = self.comm.allgather(msng if self._gather_full_data() else {})
            seen = set()
            for rankdict in gathered:
                for sysname, mset in rankdict.items():
                    if sysname not in seen:
                        missing[sysname] = mset
                        seen.add(sysname)
        else:
            super()._get_missing_partials(missing)