# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os

import paddle

# (TODO: GhostScreaming) It will be removed later.
from paddle.fluid import core
from paddle.framework import in_dynamic_mode

from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import split  # noqa: F401

__all__ = []

_global_env = None


def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups, gid 0 for the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"
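
# NOTE: roughly, 'nccl' targets NVIDIA GPUs, 'gloo' targets CPU, 'bkcl' targets
# Kunlun XPUs, 'xccl' targets custom/plug-in devices (see ProcessGroupCustom
# below), and 'heter' targets heterogeneous groups.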
_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl']
_default_store = None # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dynamic_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
    # NOTE: pg_options is accepted for interface consistency but is not consumed
    # by any of the backends below.
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
    elif backend == "xccl":
        pg = core.ProcessGroupCustom.create(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
    return pg


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# for compatibility with the static graph mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid
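

# A minimal illustrative sketch (assumes the caller runs in dynamic mode): when a
# custom gid is set, the next ``new_group`` call reuses it instead of an
# auto-assigned ring id.
#
#     _set_custom_gid(2)
#     group = new_group(ranks=[0, 1])  # the created group's gid is 2
#     _set_custom_gid(None)            # reset so later groups get auto-assigned ids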


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """
    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The timeout for store-related operations; defaults to 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2, 4, 6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dynamic_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)

        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                raise AssertionError("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if in_dynamic_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp
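

# Illustrative note on ``new_group`` (an assumption about typical multi-card usage,
# not taken from this file): because group creation ends with collective barriers,
# every rank in the job should call ``new_group`` with the same ``ranks`` argument,
# including ranks that are not members of the new group. For example, on a 4-card
# launch:
#
#     paddle.distributed.init_parallel_env()
#     even_group = paddle.distributed.new_group([0, 2])  # called on all 4 ranks
#     odd_group = paddle.distributed.new_group([1, 3])   # called on all 4 ranks
#     if paddle.distributed.get_rank() in [0, 2]:
#         paddle.distributed.all_reduce(paddle.ones([2, 2]), group=even_group)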


def is_available():
    """
    Check whether the distributed package is available.

    Returns:
        Returns True if the distributed package is available, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            print(paddle.distributed.is_available())

    """
    return core.is_compiled_with_dist()
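

# Illustrative guard (an assumption about typical usage, not from this file):
# skip distributed initialization when the wheel was built without distributed
# support.
#
#     if paddle.distributed.is_available():
#         paddle.distributed.init_parallel_env()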


def _init_parallel_env(backend):
    # Bootstrap a TCP store from the PADDLE_MASTER endpoint and create the
    # communication context for the requested backend.
    master_endpoint = os.getenv("PADDLE_MASTER", None)
    if master_endpoint:
        master_addr = master_endpoint.split(":")[0]
        master_port = int(master_endpoint.split(":")[1])
        global_env = _get_global_env()
        rank = global_env.rank
        world_size = global_env.world_size
        dev_id = global_env.device_id
        is_master = rank == 0
        store = core.TCPStore(
            master_addr,
            master_port,
            is_master,
            world_size,
        )
        if backend == "gloo":
            core.CommContextManager.create_gloo_comm_context(
                store, 0, rank, world_size
            )
        elif backend == "nccl":
            core.CommContextManager.create_nccl_comm_context(
                store, dev_id, 0, rank, world_size
            )