/
parallel.py
1285 lines (1064 loc) · 46.4 KB
/
parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import sys
import time
import warnings
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from multiprocessing import Manager, Process
import numpy as np
import paddle
from paddle import _legacy_C_ops, framework
from paddle.distributed.collective import (
Group,
_default_group_name,
_get_group_map_by_name,
_new_process_group_impl,
_set_default_backend,
_set_default_store,
_set_group_map,
_set_group_map_backend,
_set_group_map_by_name,
_valid_backend_list,
)
from paddle.distributed.communication.group import (
_add_new_group,
_get_global_group,
is_initialized,
)
from paddle.distributed.fleet.base.private_helper_function import (
wait_server_ready,
)
from paddle.distributed.fleet.launch_utils import check_backend
# (TODO: GhostScreaming) It will be removed later.
from paddle.framework import (
_set_expected_place,
base as imperative_base,
core,
in_dynamic_mode,
)
from paddle.nn.layer import layers
from paddle.utils import deprecated
from . import parallel_helper
from .backup_env import getenv_or_backup
# This module exports no names through ``from <module> import *``.
__all__ = []
# Module-level alias of ``core.ParallelStrategy`` so the rest of this file
# (and importers of this module) can refer to it without the ``core.`` prefix.
ParallelStrategy = core.ParallelStrategy
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current ``ParallelEnv``.

    Returns:
        ParallelStrategy: a new strategy whose ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` are copied from a
        freshly constructed ``paddle.distributed.ParallelEnv``.
    """
    strategy = ParallelStrategy()
    # Construct the environment snapshot once; every field below is read from
    # the same object instead of re-parsing the environment for each field.
    env = paddle.distributed.ParallelEnv()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy
def _coalesce_tensors(var_groups):
    """Fuse every group of tensors into one flat tensor.

    Args:
        var_groups: mapping of group id -> list of tensors. Only the values
            are used; they are visited in the mapping's iteration order.

    Returns:
        list: one ``[coalesced_tensor, original_tensors, original_shapes]``
        entry per group, where ``coalesced_tensor`` is the concatenation of
        the group's tensors after each has been reshaped to 1-D.
    """
    fused_entries = []
    for members in var_groups.values():
        # Remember the original shapes so the fused buffer can be split back.
        member_shapes = [tensor.shape for tensor in members]
        flat_members = [
            paddle.reshape(x=tensor, shape=[np.prod(tensor.shape)])
            for tensor in members
        ]
        fused_entries.append(
            [paddle.concat(flat_members), members, member_shapes]
        )
    return fused_entries
@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Trace a ``reshape2`` op whose output ``Out`` is ``x`` itself.

    Args:
        x: tensor passed as both the input ``X`` and the output ``Out``.
        shape: value passed as the op's ``shape`` attribute.
    """
    # ``reshape2`` also emits an ``XShape`` output; give it a scratch tensor.
    xshape_holder = framework._create_tensor(dtype=x.dtype)
    tracer = framework._dygraph_tracer()
    op_inputs = {'X': x}
    op_outputs = {'Out': x, 'XShape': xshape_holder}
    op_attrs = {'shape': shape}
    tracer.trace_op(
        type="reshape2",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs,
    )
@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter each fused tensor back into its original tensors.

    Args:
        coalesced_grads_and_grad_vars: iterable of
            ``(coalesced_grad, origin_grad_vars, grad_shapes)`` triples, as
            produced by ``_coalesce_tensors``.

    Does nothing when not in dynamic mode.
    """
    if not in_dynamic_mode():
        return

    for fused, targets, shapes in coalesced_grads_and_grad_vars:
        # Number of elements each target occupies inside the fused buffer.
        section_sizes = [np.prod(shape) for shape in shapes]
        _legacy_C_ops.split(
            fused, targets, 'sections', section_sizes, 'axis', 0
        )
        # The split writes into ``targets``; restore their original shapes.
        for target, shape in zip(targets, shapes):
            target.reshape_(shape=shape)
            assert target.shape == shape
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Bucket tensors by dtype and accumulated size, then fuse each bucket.

    A tensor joins the current bucket while the bytes already accumulated in
    that bucket are below ``group_size`` and the tensor's dtype matches the
    bucket's dtype; otherwise a new bucket is started with that tensor.

    Args:
        vars: non-empty sequence of tensors (``vars[0]`` is read
            unconditionally).
        group_size: byte threshold compared against the bytes accumulated in
            the current bucket *before* the next tensor is added.

    Returns:
        list: the result of ``_coalesce_tensors`` on the buckets.
    """
    buckets = OrderedDict()
    bucket_id = 0
    bucket_bytes = 0
    bucket_dtype = vars[0].dtype

    for tensor in vars:
        tensor_bytes = np.prod(tensor.shape) * core.size_of_dtype(tensor.dtype)
        fits_current = (
            bucket_bytes < group_size and bucket_dtype == tensor.dtype
        )
        if fits_current:
            bucket_bytes += tensor_bytes
        else:
            # Open a new bucket seeded with this tensor.
            bucket_id += 1
            bucket_dtype = tensor.dtype
            bucket_bytes = tensor_bytes
        if bucket_id not in buckets:
            buckets[bucket_id] = []
        buckets[bucket_id].append(tensor)

    return _coalesce_tensors(buckets)
@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
    model,
    comm_group=None,
    src_rank=0,
    is_model_parallel=False,
    fuse_params=True,
):
    """Broadcast a model's parameters and buffers from ``src_rank``.

    Args:
        model: object exposing ``_obtain_parameters_buffers()``, whose values
            are the tensors to synchronize.
        comm_group: passed as ``group`` to ``paddle.distributed.broadcast``.
        src_rank: passed as ``src`` to ``paddle.distributed.broadcast``.
        is_model_parallel: when true, tensors whose ``is_distributed``
            attribute is truthy are skipped.
        fuse_params: when true, tensors are fused into groups (128M
            threshold), each fused tensor is broadcast, and the result is
            split back into the originals; otherwise every tensor is
            broadcast individually.

    Raises:
        TypeError: if any value is not a ``core.eager.Tensor``.
    """
    tensors_to_sync = []
    for tensor in model._obtain_parameters_buffers().values():
        if not isinstance(tensor, core.eager.Tensor):
            raise TypeError(
                "The data type of '%s' must be core.eager.Tensor" % tensor.name
            )

        if (
            is_model_parallel
            and hasattr(tensor, "is_distributed")
            and tensor.is_distributed
        ):
            continue

        # NOTE(shenliang03): Support situations that do not require synchronization parameters,
        # such as moe's expert parameters
        if getattr(tensor, "no_sync", False):
            continue

        if tensor.type == core.VarDesc.VarType.VOCAB:
            continue

        tensors_to_sync.append(tensor.detach())

    if not tensors_to_sync:
        return

    if not fuse_params:
        for tensor in tensors_to_sync:
            paddle.distributed.broadcast(
                tensor, src=src_rank, group=comm_group, sync_op=True
            )
        return

    # group size is 128M
    fused_groups = build_groups(tensors_to_sync, 128 * 1024 * 1024)

    # First broadcast every fused buffer ...
    for fused, _, _ in fused_groups:
        paddle.distributed.broadcast(
            fused, src=src_rank, group=comm_group, sync_op=True
        )

    # ... then scatter each buffer back into the tensors it was built from.
    for fused, originals, shapes in fused_groups:
        section_sizes = [np.prod(shape) for shape in shapes]
        paddle.base.framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': fused},
            outputs={'Out': originals},
            attrs={'sections': section_sizes, 'axis': 0},
        )
class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, DataParallel class only supports to run the dynamic graph
    with multi-process.

    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)

    2. start by ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    And the content of `demo.py` is the code of examples.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): It limits the memory size(MB) of one buffer
                                         parameters' gradient which is the input of communication
                                         calling(e.g NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
                                         calling. Making the last communication buffer size small is useful to
                                         improve performance. Default: 1.
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
                                                all tensors in the return value of the wrapped model's
                                                forward function. For parameters not involved in loss
                                                calculation, their gradients will be marked as ready in
                                                advance to prepare reduce. Please note that all forward
                                                outputs derived from the wrapped model parameters must
                                                participate in the calculation of loss and subsequent
                                                gradient calculations. If not, serious error will occur.
                                                Note that setting the find_unused_parameters to True
                                                will affect computing performance. Therefore, if all parameters
                                                are sure to participate in the loss calculation and the
                                                autograd graph construction, please set it False. Default: False.
        group(Group, optional): The communication group used by the gradient reducer. When None
            and more than one rank is running, the default group is used. Default: None.

    Returns:
        Layer: The data paralleled module.

    Examples:

        .. code-block:: python
            :name: dp-example

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle
            >>> import paddle.nn as nn
            >>> import paddle.optimizer as opt
            >>> import paddle.distributed as dist

            >>> class LinearNet(nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self._linear1 = nn.Linear(10, 10)
            ...         self._linear2 = nn.Linear(10, 1)
            ...     def forward(self, x):
            ...         return self._linear2(self._linear1(x))

            >>> def train():
            ...     # 1. initialize parallel environment
            ...     dist.init_parallel_env()
            ...     # 2. create data parallel layer & optimizer
            ...     layer = LinearNet()
            ...     dp_layer = paddle.DataParallel(layer)
            ...     loss_fn = nn.MSELoss()
            ...     adam = opt.Adam(
            ...         learning_rate=0.001, parameters=dp_layer.parameters())
            ...     # 3. run layer
            ...     inputs = paddle.randn([10, 10], 'float32')
            ...     outputs = dp_layer(inputs)
            ...     labels = paddle.randn([10, 1], 'float32')
            ...     loss = loss_fn(outputs, labels)
            ...     loss.backward()
            ...     adam.step()
            ...     adam.clear_grad()

            >>> if __name__ == '__main__':
            ...     # 1. start by ``paddle.distributed.spawn`` (default)
            ...     dist.spawn(train, nprocs=2)
            ...     # 2. start by ``paddle.distributed.launch``
            ...     # train()

    .. note::
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
        and manually implement 'all_reduce' before model optimization. There is an example
        showing specific implementation processing.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import numpy
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> from paddle.autograd import PyLayer
            >>> from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            >>> class cus_tanh(PyLayer):
            ...     @staticmethod
            ...     def forward(ctx, x):
            ...         y = paddle.tanh(x)
            ...         ctx.save_for_backward(y)
            ...         return y
            ...     @staticmethod
            ...     def backward(ctx, dy):
            ...         y, = ctx.saved_tensor()
            ...         grad = dy * (1 - paddle.square(y))
            ...         return grad

            >>> class SimpleNet(paddle.nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.linear = paddle.nn.Linear(2, 2)
            ...     def forward(self, inputs):
            ...         inputs = cus_tanh.apply(inputs)
            ...         return self.linear(inputs)

            >>> if __name__ == '__main__':
            ...     dist.init_parallel_env()
            ...     model = SimpleNet()
            ...     model = paddle.DataParallel(model)
            ...     opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            ...     for step in range(10):
            ...         x_data = numpy.random.randn(2, 2).astype(numpy.float32)
            ...         x = paddle.to_tensor(x_data)
            ...         x.stop_gradient = False
            ...         # step 1 : skip gradient synchronization by 'no_sync'
            ...         with model.no_sync():
            ...             y_pred = model(x)
            ...             loss = y_pred.mean()
            ...             loss.backward()
            ...         # step 2 : fuse + allreduce manually before optimization
            ...         fused_allreduce_gradients(list(model.parameters()), None)
            ...         opt.step()
            ...         opt.clear_grad()

    """

    def __init__(
        self,
        layers,
        strategy=None,
        comm_buffer_size=25,
        last_comm_buffer_size=1,
        find_unused_parameters=False,
        group=None,
    ):
        super().__init__(layers.full_name() + "_data_parallel")

        assert (
            in_dynamic_mode()
        ), "It's not supported to construct DataParallel in static graph mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        # Toggled off temporarily by ``no_sync()``.
        self.grad_need_sync = True
        self.group = group
        self.var_dtype = core.eager.Tensor

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            # The trailing space after "before" keeps the two adjacent string
            # literals from running together in the rendered message.
            assert parallel_helper.__parallel_ctx__clz__ is not None, (
                "ParallelContext must be initialized before. You should use init_parallel_env() before "
                "constructing the DataParallel."
            )

            if in_dynamic_mode():
                self.group = (
                    paddle.distributed.collective._get_default_group()
                    if self.group is None
                    else self.group
                )

                assert isinstance(
                    self.group, paddle.distributed.collective.Group
                ), "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # NOTE(review): this call relies on sync_params_buffers' defaults
            # for ``comm_group`` and ``src_rank``; ``self.group`` is not
            # forwarded. Confirm that is intended when a custom group is given.
            sync_params_buffers(self._layers, fuse_params=False)

            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making the
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(
                last_comm_buffer_size * 1024 * 1024
            )
            self.init_reducer()
        else:
            warnings.warn(
                "The program will return to single-card operation. "
                "Please check 1, whether you use spawn or fleetrun "
                "to start the program. 2, Whether it is a multi-card "
                "program. 3, Is the current environment multi-card."
            )

    def init_reducer(self):
        """Collect the trainable parameters and build the gradient reducer.

        Sets ``self.group_indices`` and ``self._reducer`` (in dynamic mode).

        Raises:
            TypeError: if a parameter is not an instance of ``self.var_dtype``.
            AssertionError: if no trainable parameter remains to synchronize.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                # Skip missing parameters and ones already seen through
                # another sublayer.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError(
                        f"The data type of '{param.name}' must be '{self.var_dtype}'"
                    )
                if param.trainable:
                    layers_param.append((sublayer, param))

        # Drop parameters flagged ``no_sync`` *before* deriving the
        # per-parameter lists, so that ``trainable_parameters`` and
        # ``is_sparse_gradient`` always have the same length and order.
        synced_layers_param = [
            (sublayer, param)
            for sublayer, param in layers_param
            if not getattr(param, "no_sync", False)
        ]
        trainable_parameters = [param for _, param in synced_layers_param]

        assert len(trainable_parameters) > 0, (
            "This model does not have any parameters to train, and "
            "does not need to use DataParallel"
        )

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in synced_layers_param
        ]

        if in_dynamic_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters,
                is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size],
            )

            self._reducer = core.EagerReducer(
                trainable_parameters,
                list(reversed(self.group_indices)),
                is_sparse_gradient,
                self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters,
            )

    def _find_tensor(self, obj):
        """Return an iterable of every tensor nested inside ``obj``.

        Lists, tuples and dict values are searched recursively; any other
        non-tensor object yields nothing.
        """
        var_type = core.eager.Tensor
        if isinstance(obj, var_type):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return itertools.chain(*map(self._find_tensor, obj))
        if isinstance(obj, dict):
            return itertools.chain(*map(self._find_tensor, obj.values()))
        return []

    @contextmanager
    def no_sync(self):
        """
        A context manager to stop gradient synchronization. Within no_sync(),
        gradients of parameters will only be accumulated on model and not
        synchronized until the first forward-backward out of this context.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.nn as nn
                >>> import paddle.distributed as dist

                >>> class SimpleNet(nn.Layer):
                ...     def __init__(self):
                ...         super().__init__()
                ...         self._linear = nn.Linear(10, 1)
                ...     def forward(self, x):
                ...         return self._linear(x)

                >>> dist.init_parallel_env()
                >>> model = SimpleNet()
                >>> dp_model = paddle.DataParallel(model)

                >>> inputs_1 = paddle.randn([10, 10], 'float32')
                >>> inputs_2 = paddle.ones([10, 10], 'float32')

                >>> with dp_model.no_sync():
                ...     # gradients will not be synchronized
                ...     dp_model(inputs_1).backward()

                >>> # synchronization happens here
                >>> dp_model(inputs_2).backward()

        """
        tmp_grad_need_sync = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            # Restore the previous value even if the body raised.
            self.grad_need_sync = tmp_grad_need_sync

    def forward(self, *inputs, **kwargs):
        """Run the wrapped layer and, if needed, prepare the reducer.

        The reducer is prepared only when more than one rank is running,
        the tracer has gradients enabled, and synchronization has not been
        disabled through ``no_sync()``.
        """
        outputs = self._layers(*inputs, **kwargs)
        if (
            self._strategy.nranks > 1
            and framework._dygraph_tracer()._has_grad
            and self.grad_need_sync
        ):
            self._reducer.prepare_for_backward(list(self._find_tensor(outputs)))
        return outputs

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def scale_loss(self, loss):
        """
        Deprecated method, now ``scale_loss`` is an empty method,
        keep this method just for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def apply_collective_grads(self):
        """
        Deprecated method, now ``apply_collective_grads`` is an empty method,
        keep this method just for compatibility.
        """
        return

    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded unchanged to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.distributed as dist

                >>> dist.init_parallel_env()

                >>> emb = paddle.nn.Embedding(10, 10)
                >>> emb = paddle.DataParallel(emb)

                >>> state_dict = emb.state_dict()
                >>> paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
        )

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.distributed as dist

                >>> dist.init_parallel_env()

                >>> emb = paddle.nn.Embedding(10, 10)
                >>> emb = paddle.DataParallel(emb)

                >>> state_dict = emb.state_dict()
                >>> paddle.save(state_dict, "paddle_dy.pdparams")

                >>> para_state_dict = paddle.load("paddle_dy.pdparams")
                >>> emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict, use_structured_name=use_structured_name
        )

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict
# NOTE(chenweihang): Maintain a global parallel env to avoid
# initializing ParallelEnv every time and improve performance.
# Starts as None; ``_get_global_parallel_env`` fills it in on first use.
_global_parallel_env = None
class ParallelEnv:
    """
    .. note::
        This API is not recommended, if you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> def train():
            ...     # 1. initialize parallel environment
            ...     dist.init_parallel_env()
            ...     # 2. get current ParallelEnv
            ...     parallel_env = dist.ParallelEnv()
            ...     print("rank: ", parallel_env.rank)
            ...     print("world_size: ", parallel_env.world_size)

            >>> if __name__ == '__main__':
            ...     # 1. start by ``paddle.distributed.spawn`` (default)
            ...     dist.spawn(train, nprocs=2)
            ...     # 2. start by ``paddle.distributed.launch``
            ...     train()

            # Print result in process 1:
            rank: 1
            world_size: 2

            # Print result in process 2:
            rank: 2
            world_size: 2

    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))
        self._pg_timeout = int(os.getenv("PADDLE_PG_TIMEOUT", "1800000"))

        # Documented default. Without it, a build with neither CUDA nor XPU
        # and no custom backend would leave the attribute unset and the
        # ``device_id`` property would raise AttributeError.
        self._device_id = 0

        # imperative only support one gpu or xpu
        if self._device_type != "":
            # Custom device: the flag name is derived from the backend name.
            FLAGS_selected_custom_devices = (
                f'FLAGS_selected_{self._device_type}s'
            )
            selected_custom_devices = os.getenv(
                FLAGS_selected_custom_devices, "0"
            ).split(",")
            self._device_id = int(selected_custom_devices[0])
        else:
            if core.is_compiled_with_cuda():
                selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
                self._device_id = int(selected_gpus[0])
            elif core.is_compiled_with_xpu():
                selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
                self._device_id = int(selected_xpus[0])

        self._trainer_endpoints = getenv_or_backup(
            "PADDLE_TRAINER_ENDPOINTS", ""
        ).split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert (
            self._nrings > 0
        ), "nccl_nrings must be an integer greater than 0."
        assert (
            self._nrings < 9
        ), "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINER_ID=0
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The rank is %d" % env.rank)
                The rank is 0

        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The world_size is %d" % env.world_size)
                The world_size is 4

        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is equal to the first entry of the environment variable ``FLAGS_selected_gpus``
        (``FLAGS_selected_xpus`` on XPU builds, or ``FLAGS_selected_<backend>s`` when
        ``PADDLE_XCCL_BACKEND`` is set). The default value is 0.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export FLAGS_selected_gpus=1
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The device id are %d" % env.device_id)
                The device id are 1

        """
        return self._device_id

    @property
    def device_type(self):
        """
        The type of custom device for parallel training.

        Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` . The default value is an empty string.

        """
        return self._device_type

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The current endpoint are %s" % env.current_endpoint)
                The current endpoint are 127.0.0.1:6170

        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The trainer endpoints are %s" % env.trainer_endpoints)
                The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']

        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export FLAGS_nccl_nrings=1
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The nrings is %d" % env.nrings)
                The nrings is 1

        """
        return self._nrings

    @property
    def pg_timeout(self):
        """
        timeout of process group.

        Its value is equal to the value of the environment variable ``PADDLE_PG_TIMEOUT`` . The default value is 30 minutes.

        Examples:
            .. code-block:: python

                >>> # execute this command in terminal: export PADDLE_PG_TIMEOUT=1800000
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> # the pg_timeout of process group 1800000

        """
        return self._pg_timeout

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id
def _get_global_parallel_env():
    """Return the module-wide ``ParallelEnv``, creating it on first use."""
    global _global_parallel_env
    cached_env = _global_parallel_env
    if cached_env is None:
        # First call: build the environment once and remember it.
        cached_env = ParallelEnv()
        _global_parallel_env = cached_env
    return cached_env
def _start_kv_server(port, http_server_d, size):
    """Run a ``KVServer`` until it is told to stop.

    Args:
        port: port to serve on; converted with ``int``.
        http_server_d: mapping whose ``"running"`` entry keeps the server
            alive while truthy.
        size: forwarded to ``KVServer`` as ``size``.
    """
    from paddle.distributed.fleet.utils.http_server import KVServer

    poll_interval = 3
    kv_server = KVServer(int(port), size=size)
    kv_server.start()
    while True:
        # Leave only once the "running" flag is cleared AND the server
        # itself reports that it should stop.
        if not http_server_d.get("running", False) and kv_server.should_stop():
            break
        time.sleep(poll_interval)
    kv_server.stop()
def _is_cpuonly(backend):
    """Tell whether ``backend`` resolves to CPU-only execution.

    ``backend`` is first validated with ``check_backend``. Returns False for
    ``'xccl'``, and for ``'auto'``/``'nccl'``/``'bkcl'``/``'heter'`` when the
    build has CUDA or XPU support; returns True otherwise.
    """
    check_backend(backend)

    if backend == 'xccl':
        return False

    if backend in ['auto', 'nccl', 'bkcl', 'heter']:
        # passes 'auto' and can use cuda or xpu, use the default logics. so return False
        has_accelerator = (
            core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
        )
        return not has_accelerator

    return True
def _check_var_exists(var_name):
    """Ensure ``getenv_or_backup`` can resolve ``var_name``.

    Raises:
        ValueError: if ``getenv_or_backup(var_name, None)`` returns None.
    """
    if getenv_or_backup(var_name, None) is not None:
        return
    raise ValueError(
        "paddle.distributed initialize error, "
        "environment variable %s is needed, but not set." % var_name
    )
def _get_modified_flags():
    """List the global flags whose current value differs from the default.

    Returns:
        list: ``FLAGS(name, current_value, default_value)`` named tuples, one
        per key of ``core.globals()`` whose value is not equal to its default.
    """
    FlagRecord = namedtuple(
        'FLAGS', ['name', 'current_value', 'default_value']
    )
    registry = core.globals()
    changed = []
    for flag_name in registry.keys():
        current = registry.get(flag_name)
        default = registry.get_default(flag_name)
        if current == default:
            continue
        changed.append(FlagRecord(flag_name, current, default))
    return changed
def _print_modified_flags(modified_flags):
    """Write a banner-framed listing of ``modified_flags`` to stderr.

    Nothing is written when ``modified_flags`` is empty. Otherwise each entry
    is written as ``str(entry)`` on its own line between two banner lines.
    """
    if len(modified_flags) == 0:
        return

    emit = sys.stderr.write
    emit(
        "======================= Modified FLAGS detected =======================\n"
    )
    for entry in modified_flags:
        emit(str(entry))
        emit("\n")
    emit(
        "=======================================================================\n"
    )
def init_parallel_env():
"""
Initialize parallel training environment in dynamic graph mode.
Note:
Now initialize both `NCCL` and `GLOO` contexts for communication.
Args:
backend (string): A string represents the backend used by DataParallel,
should be one of 'gloo'(for cpu), 'nccl'(for cuda), 'bkcl'(for xpu), 'auto'(auto detect).
The auto detection prefer 'nccl', 'bkcl' than 'gloo'.
Returns:
None
Examples:
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU, env:DISTRIBUTED)
>>> import paddle
>>> import paddle.nn as nn
>>> import paddle.optimizer as opt
>>> import paddle.distributed as dist
>>> class LinearNet(nn.Layer):
... def __init__(self):
... super().__init__()
... self._linear1 = nn.Linear(10, 10)
... self._linear2 = nn.Linear(10, 1)
... def forward(self, x):
... return self._linear2(self._linear1(x))
>>> def train():
... # 1. initialize parallel environment
... dist.init_parallel_env()
... # 2. create data parallel layer & optimizer
... layer = LinearNet()
... dp_layer = paddle.DataParallel(layer)
... loss_fn = nn.MSELoss()
... adam = opt.Adam(
... learning_rate=0.001, parameters=dp_layer.parameters())
... # 3. run layer
... inputs = paddle.randn([10, 10], 'float32')
... outputs = dp_layer(inputs)
... labels = paddle.randn([10, 1], 'float32')
... loss = loss_fn(outputs, labels)
... loss.backward()
... adam.step()
... adam.clear_grad()
>>> if __name__ == '__main__':
... dist.spawn(train)
"""
modified_flags = _get_modified_flags()
_print_modified_flags(modified_flags)