/
communication.py
1082 lines (827 loc) · 33.6 KB
/
communication.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
import multiprocessing
import bagua_core as B
from bagua.service import AutotuneService
from . import env
from .env import (
get_master_addr,
get_world_size,
get_rank,
get_local_rank,
get_local_size,
get_default_bucket_size,
get_bagua_service_port,
)
from enum import IntEnum
from .utils import flatten, unflatten
import torch
import torch.distributed as dist
from bagua.service.autotune_service import AutotuneClient
from functools import lru_cache
from datetime import timedelta
from typing import Optional, List
# fmt: off
# Public API of this module: only the names listed here are exported by a
# star-import; everything else in the file must be imported explicitly.
__all__ = [
"ReduceOp", "new_group", "from_torch_group", "init_process_group",
"is_initialized", "send", "recv", "broadcast", "reduce", "reduce_inplace",
"allreduce", "allreduce_inplace", "allgather", "allgather_inplace",
"gather", "gather_inplace", "scatter", "scatter_inplace",
"reduce_scatter", "reduce_scatter_inplace", "alltoall", "alltoall_inplace"
]
# Maps each BaguaProcessGroup to a ``{global_rank: group_rank}`` dict;
# entries are added by new_group().
_pg_group_ranks = {}
# Maps a process group's name to its BaguaProcessGroup; entries are added by
# new_group() and looked up by get_communicator().
_pg_map = {}
# The default process group; stays ``None`` until init_process_group() sets it.
_default_pg = None
# The key/value store recorded by init_process_group(); ``None`` until then.
_default_store = None
# Number of process groups created so far; new_group() increments it and uses
# the new value (as a string) as the group's name.
_group_count = 0
# must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
class ReduceOp(IntEnum):
    """Reduction operations accepted by the collective functions.

    Members: ``SUM``, ``PRODUCT``, ``MIN``, ``MAX``, ``BOR``, ``BAND``,
    ``BXOR`` and ``AVG``. The collectives in this module pass ``int(op)`` to
    the backend, so the numeric values are part of the contract and must not
    be renumbered.
    """

    # Arithmetic reductions.
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3

    # Bitwise reductions (values 4-6 are intentionally left undefined here).
    BOR = 7
    BAND = 8
    BXOR = 9

    # Averaging reduction.
    AVG = 10
def _check_default_pg():
    """Assert that the default process group has been created.

    Raises:
        AssertionError: If the default process group is still uninitialized.
    """
    assert is_initialized(), "Default process group is not initialized"
def is_initialized():
    """Report whether the default process group has been initialized.

    Returns:
        ``True`` once the module-level default process group is set,
        ``False`` while it is still ``None``.
    """
    default_group = _default_pg
    return default_group is not None
def _get_default_group():
    """Return the default process group created by :func:`init_process_group`.

    Raises:
        RuntimeError: If the default process group has not been initialized.
    """
    if is_initialized():
        return _default_pg

    raise RuntimeError(
        "Default process group has not been initialized, "
        "please make sure to call init_process_group."
    )
def _rank_not_in_comm(comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    """Tell whether the current process is excluded from ``comm``.

    Args:
        comm: A communicator handle, ``None`` (meaning "use the default"), or
            the ``CommMember.NON_COMM_MEMBER`` sentinel.

    Returns:
        ``True`` only when ``comm`` is the ``CommMember.NON_COMM_MEMBER``
        sentinel; ``False`` for ``None`` and for any other value.
    """
    if comm is None:
        return False
    # NON_COMM_MEMBER is a unique sentinel object, so compare by identity
    # (consistent with the ``comm is CommMember.WORLD`` checks in this module)
    # rather than relying on the communicator type's ``__eq__``.
    return comm is CommMember.NON_COMM_MEMBER
def _bagua_backend_comm(comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    """Translate a communicator handle into what is handed to the backend.

    Returns:
        ``None`` when the current process's rank is not in ``comm``;
        otherwise ``comm`` itself, unchanged.
    """
    return None if _rank_not_in_comm(comm) else comm
def new_group(
    ranks: Optional[List[int]] = None, stream: Optional[torch.cuda.Stream] = None
):
    """
    Creates a new process group.

    This function requires that all processes in the default group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    Each process group will create three communicators on request, a global communicator,
    an inter-node communicator and an intra-node communicator. Users can access them through
    ``group.get_global_communicator()``, ``group.get_inter_node_communicator()``
    and ``group.get_intra_node_communicator()`` respectively.

    Args:
        ranks: List of ranks of group members. If ``None``, will be
            set to all ranks. Default is ``None``.
        stream: A CUDA stream used to execute NCCL operations. If ``None``,
            CUDA stream of the default group will be used. See
            `CUDA semantics <https://pytorch.org/docs/stable/notes/cuda.html?highlight=stream>`_
            for details.

    Returns:
        A handle of process group that can be given to collective calls.

    Raises:
        ValueError: If any entry of ``ranks`` is negative or not less than the
            world size.
        AssertionError: If ``stream`` is ``None`` and the default process group
            has not been initialized yet.

    .. note::
        The global communicator is used for global communications involving all ranks in the process group.
        The inter-node communicator and the intra-node communicator are used for hierarchical communications
        in this process group.

    .. note::
        For a specific communicator ``comm``, ``comm.rank()`` returns the rank of current process and
        ``comm.nranks()`` returns the size of the communicator.
    """
    global _group_count

    # The counter is bumped first so that group names stay aligned across
    # processes, which are all required to call this function in the same order.
    _group_count += 1

    if ranks is None:
        ranks = list(range(get_world_size()))
    else:
        # Sanity check for the input ranks.
        world_size = get_world_size()
        for rank in ranks:
            if rank < 0 or rank >= world_size:
                # The placeholders must be filled in explicitly; passing the
                # values as extra exception args leaves the message unformatted.
                raise ValueError(
                    "Invalid rank {}, should be non-negative and less than world size {}.".format(
                        rank, world_size
                    )
                )
        ranks = sorted(ranks)

    if stream is None:
        _check_default_pg()
        stream = _get_default_group().stream

    group_name = str(_group_count)
    pg = BaguaProcessGroup(ranks, stream, group_name)

    # Create the global rank to group rank mapping.
    _pg_group_ranks[pg] = {
        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
    }
    _pg_map[group_name] = pg

    return pg
def from_torch_group(group, stream: Optional[torch.cuda.Stream] = None):
    """
    Build the Bagua process group equivalent to a PyTorch process group.

    Args:
        group: A handle of the PyTorch process group.
        stream: A CUDA stream used to execute NCCL operations. If ``None``,
            CUDA stream of the default group will be used. See :func:`new_group`
            for more information.

    Returns:
        A handle of the Bagua process group.
    """
    import torch.distributed.distributed_c10d as c10d

    # The keys of PyTorch's per-group rank mapping are the member ranks that
    # are handed to new_group().
    torch_rank_map = c10d._pg_group_ranks[group]
    member_ranks = list(torch_rank_map)
    return new_group(member_ranks, stream)
class BaguaProcessGroup:
    """A set of ranks plus the CUDA stream their communicators run on.

    Communicators are not built here; they are obtained on demand through
    :func:`get_communicator` by the three ``get_*_communicator`` methods.
    """

    def __init__(self, ranks, stream, group_name):
        """Store the group description and derive its intra/inter rank lists.

        Args:
            ranks: Ranks that belong to this group.
            stream: CUDA stream associated with this group.
            group_name: Name under which the group is registered.
        """
        self.ranks = ranks
        self.stream = stream
        self.group_name = group_name

        # Members whose ``rank // local_size`` equals that of the current
        # process (i.e. the same node, assuming ranks are assigned
        # contiguously per node -- TODO confirm).
        self.intra_ranks = [
            member
            for member in ranks
            if member // get_local_size() == get_rank() // get_local_size()
        ]

        # Members whose ``rank % local_size`` equals that of the first member
        # of the group.
        self.inter_ranks = [
            member
            for member in ranks
            if member % get_local_size() == ranks[0] % get_local_size()
        ]

        logging.debug(f"Initialize Bagua process group of ranks {self.ranks}")

    def get_global_communicator(self):
        """Return the communicator spanning all ranks of this group."""
        return get_communicator(self.group_name, "global")

    def get_inter_node_communicator(self):
        """Return the communicator over this group's ``inter_ranks``."""
        return get_communicator(self.group_name, "inter")

    def get_intra_node_communicator(self):
        """Return the communicator over this group's ``intra_ranks``."""
        return get_communicator(self.group_name, "intra")
@lru_cache(maxsize=None)
def get_communicator(group_name: str, comm_name: str):
    """Create (once) and return a communicator of a registered process group.

    Results are memoized per ``(group_name, comm_name)`` by ``lru_cache``, so
    repeated calls return the same object.

    Args:
        group_name: Name of a process group previously registered by
            :func:`new_group`.
        comm_name: Which communicator to build: ``"global"``, ``"inter"`` or
            ``"intra"``.

    Returns:
        ``CommMember.NON_COMM_MEMBER`` if the current rank is not part of the
        selected rank list; otherwise a ``B.BaguaSingleCommunicatorPy`` whose
        ``cuda_stream`` attribute is set to the group's stream.

    Raises:
        KeyError: If ``group_name`` is not registered.
        ValueError: If ``comm_name`` is not one of the accepted values.
    """
    pg = _pg_map[group_name]

    if comm_name == "global":
        ranks = pg.ranks
    elif comm_name == "inter":
        ranks = pg.inter_ranks
    elif comm_name == "intra":
        ranks = pg.intra_ranks
    else:
        raise ValueError("comm_name should be one of ['global', 'inter', 'intra']")

    # Decide membership before touching ``ranks[0]`` or the store: a
    # non-member may see an empty rank list (e.g. no group member matches its
    # intra-node filter), and it has no use for the NCCL unique id, so there
    # is no reason to block on the store waiting for the root to publish it.
    if get_rank() not in ranks:
        return CommMember.NON_COMM_MEMBER

    comm_key = "{}_{}_{}".format(group_name, comm_name, ",".join(map(str, ranks)))
    # The first rank of the list generates the id; the others read it.
    nccl_unique_id = broadcast_nccl_unique_id(comm_key, root=ranks[0])

    rank = ranks.index(get_rank())
    nranks = len(ranks)

    comm = B.BaguaSingleCommunicatorPy(
        rank=rank,
        nranks=nranks,
        device_id=get_local_rank(),
        stream_ptr=pg.stream.cuda_stream,
        nccl_unique_id_str=nccl_unique_id,
    )

    logging.debug(
        "init bagua communicator %s-%s ok, global rank: %s rank: %s",
        group_name,
        comm_name,
        get_rank(),
        comm.rank(),
    )

    # Keep the torch stream object on the communicator; the collective
    # wrappers in this module use ``comm.cuda_stream`` for synchronization.
    comm.cuda_stream = pg.stream
    return comm
@lru_cache(maxsize=None)
def get_hyperparameters_service_client():
    """Return the autotune service client, creating it on first use.

    The client is constructed from the master address and the Bagua service
    port; ``lru_cache`` makes later calls return the same instance.
    """
    return AutotuneClient(get_master_addr(), get_bagua_service_port())
@lru_cache(maxsize=None)
def get_backend(model_name: str):
    """Return the communication backend for ``model_name``.

    One backend is created per distinct ``model_name`` (memoized by
    ``lru_cache``) on the device given by the local rank, and tagged with the
    model name.
    """
    # NOTE(review): the meaning of the first positional argument (100) is not
    # visible from this file -- confirm against bagua_core.
    comm_backend = B.BaguaCommBackendPy(100, device_id=get_local_rank())
    comm_backend.model_name = model_name
    return comm_backend
def run_flask_app():
    """Build the autotune service and serve it with Flask (blocking).

    The service is configured from the environment helpers, mounted on a new
    Flask application and served on all interfaces at the Bagua service port.
    """
    import os

    from flask import Flask

    # NOTE(review): presumably set so werkzeug treats this process as the main
    # serving process -- confirm the intended effect.
    os.environ["WERKZEUG_RUN_MAIN"] = "true"

    service = AutotuneService(
        world_size=get_world_size(),
        autotune_level=env.get_autotune_level(),
        max_samples=env.get_autotune_max_samples(),
        sampling_confidence_time_s=env.get_autotune_sampling_confidence_time_s(),
        warmup_time_s=env.get_autotune_warmup_time_s(),
        is_output_autotune_log=env.get_is_output_autotune_log(),
        default_bucket_size=get_default_bucket_size(),
    )

    flask_app = service.setup_app(Flask(__name__))

    # Only let werkzeug report errors.
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    flask_app.run(
        host="0.0.0.0",
        port=get_bagua_service_port(),
        debug=False,
        use_debugger=False,
        use_reloader=False,
    )
# Handle of the background autotune server process; ``None`` until started.
_autotune_server = None


def start_autotune_server():
    """Launch the autotune server in a daemon process.

    The process runs :func:`run_flask_app` and is recorded in the module-level
    ``_autotune_server`` variable.
    """
    global _autotune_server

    server_process = multiprocessing.Process(target=run_flask_app)
    server_process.daemon = True
    _autotune_server = server_process
    server_process.start()
def init_process_group(store: Optional[torch.distributed.Store] = None):
    """Set up Bagua's default process group.

    Must be executed before any other Bagua API. On rank 0 the autotune
    server is started if it is not running yet. The PyTorch builtin
    distributed process group (NCCL backend) is initialized as well when it
    has not been initialized already.

    Args:
        store: Key/value store accessible to all workers, used to exchange
            connection/address information. If ``None``, a store is obtained
            from the ``env://`` rendezvous with a 30 minute timeout.
            Default: ``None``.

    Raises:
        RuntimeError: If the default process group or the default store has
            already been initialized.

    Examples::
        >>> import torch
        >>> import bagua.torch_api as bagua
        >>>
        >>> torch.cuda.set_device(bagua.get_local_rank())
        >>> bagua.init_process_group()
        >>>
        >>> model = torch.nn.Sequential(
        ...    torch.nn.Linear(D_in, H),
        ...    torch.nn.ReLU(),
        ...    torch.nn.Linear(H, D_out),
        ...    )
        >>> optimizer = torch.optim.SGD(
        ...    model.parameters(),
        ...    lr=0.01,
        ...    momentum=0.9
        ...    )
        >>> model = model.with_bagua([optimizer], ...)
    """
    global _default_pg
    global _default_store

    # Only rank 0 hosts the autotune server, and only one instance of it.
    if get_rank() == 0 and _autotune_server is None:
        start_autotune_server()

    if _default_pg is not None:
        raise RuntimeError("trying to initialize the default process group twice!")

    if _default_store is not None:
        raise RuntimeError("The default store has been initialized else where!")

    if store is None:
        # No store supplied: take the one produced by the env:// rendezvous.
        timeout = timedelta(minutes=30)
        rendezvous = torch.distributed.rendezvous(url="env://", timeout=timeout)
        store, _unused_rank, _unused_world_size = next(rendezvous)
        store.set_timeout(timeout)

    _default_store = store

    # TODO remove the dependency on torch process group
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
            store=_default_store,
            rank=get_rank(),
            world_size=get_world_size(),
        )

    # The default group spans all ranks and gets its own CUDA stream.
    default_stream = torch.cuda.Stream(priority=-1)
    _default_pg = new_group(stream=default_stream)
def broadcast_nccl_unique_id(comm_key: str, root):
    """Share an NCCL unique id through the default store.

    Args:
        comm_key: Store key under which the id is published.
        root: Global rank that generates and publishes the id.

    Returns:
        The id as a string: freshly generated on ``root``, read from the
        store (and decoded as UTF-8) on every other rank.
    """
    if get_rank() != root:
        # The store returns bytes; callers expect text.
        stored = _default_store.get(comm_key)
        return str(stored, encoding="utf-8")

    unique_id = B.BaguaSingleCommunicatorPy.generate_nccl_unique_id_str()
    _default_store.set(comm_key, unique_id)
    return unique_id
class comm:
    """Namespace holding communicator sentinels."""

    # Unique marker object standing for the default (world) communicator.
    WORLD = object()
class CommMember:
    """Sentinel values describing a process's relation to a communicator."""

    # Alias to comm.WORLD, kept for backward compatibility.
    WORLD = comm.WORLD

    # Unique marker returned in place of a communicator when the current
    # process is not one of its members.
    NON_COMM_MEMBER = object()
def send(tensor: torch.Tensor, dst: int, comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    r"""Synchronously send ``tensor`` to rank :attr:`dst`.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        tensor: Tensor to send; must not live on the CPU.
        dst: Destination rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.send(backend_tensor, dst)

    # Block until the transfer has completed.
    comm_stream.synchronize()
def recv(tensor: torch.Tensor, src: int, comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    r"""Synchronously receive data from rank :attr:`src` into ``tensor``.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        tensor: Tensor to fill with received data; must not live on the CPU.
        src: Source rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.recv(backend_tensor, src)

    # Block until the data has arrived.
    comm_stream.synchronize()
def broadcast_coalesced(tensors, src=0, comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    """Broadcast several tensors from rank ``src`` with a single collective.

    The tensors are flattened into one buffer, the buffer is broadcast, and
    the result is copied back into each input tensor. Does nothing when the
    current process is not a member of ``comm``.

    Args:
        tensors: Tensors to broadcast; none of them may live on the CPU.
        src: Source rank. Default: 0.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    for item in tensors:
        assert item.device != torch.device(
            "cpu"
        ), "input tensors must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        flat_buffer = flatten(tensors)
        comm.broadcast(flat_buffer.to_bagua_tensor().bagua_backend_tensor(), src)
        # Scatter the broadcast buffer back into the caller's tensors.
        pieces = unflatten(flat_buffer, tensors)
        for target, piece in zip(tensors, pieces):
            target.copy_(piece)

    # TODO: remove
    comm_stream.synchronize()
def broadcast(tensor: torch.Tensor, src: int = 0, comm: Optional[B.BaguaSingleCommunicatorPy] = None):
    r"""Broadcast ``tensor`` to all processes associated with the communicator.

    :attr:`tensor` must have the same number of elements in all processes
    participating in the collective. Does nothing when the current process is
    not a member of ``comm``.

    Args:
        tensor: Data to be sent if :attr:`src` is the rank of
            current process, and tensor to be used to save received data
            otherwise. Must not live on the CPU.
        src: Source rank. Default: 0.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.broadcast(backend_tensor, src)

    # TODO: remove
    comm_stream.synchronize()
def reduce(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    dst: int,
    op: ReduceOp = ReduceOp.SUM,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    r"""Reduce the tensor data across all processes.

    Only the process with rank :attr:`dst` is going to receive the final result.
    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor: Input of the collective; must not live on the CPU.
        recv_tensor: Output of the collective, must have the same size with
            :attr:`send_tensor`; must not live on the CPU.
        dst: Destination rank.
        op: One of the values from :class:`ReduceOp`
            enum. Specifies an operation used for element-wise reductions.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.reduce(source, target, dst, int(op))

    # Block until the reduction has completed.
    comm_stream.synchronize()
def reduce_inplace(
    tensor: torch.Tensor, dst: int, op: ReduceOp = ReduceOp.SUM, comm: Optional[B.BaguaSingleCommunicatorPy] = None
):
    r"""The in-place version of :func:`reduce`.

    ``tensor`` serves as both input and output and must not live on the CPU.
    Does nothing when the current process is not a member of ``comm``.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.reduce_inplace(backend_tensor, dst, int(op))

    # Block until the reduction has completed.
    comm_stream.synchronize()
def allreduce_coalesced_inplace(
    tensors,
    op: ReduceOp = ReduceOp.SUM,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """All-reduce several tensors in place with a single collective.

    The tensors are flattened into one buffer, the buffer is all-reduced in
    place, and the result is copied back into each input tensor. Does nothing
    when the current process is not a member of ``comm``.

    Args:
        tensors: Tensors to reduce; none of them may live on the CPU.
        op: One of the values from :class:`ReduceOp` enum. Specifies an
            operation used for element-wise reductions.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    for tensor in tensors:
        assert tensor.device != torch.device(
            "cpu"
        ), "input tensors must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    # The communication stream must wait for work already queued on the
    # caller's current stream.
    event = torch.cuda.current_stream().record_event()
    comm.cuda_stream.wait_event(event)

    with torch.cuda.stream(comm.cuda_stream):
        coalesced = flatten(tensors)
        # Hand the backend tensor to the communicator, exactly as
        # allreduce_inplace() and every other collective in this module do;
        # passing the torch-side tensor object was inconsistent with them.
        comm.allreduce_inplace(
            coalesced.to_bagua_tensor("allreduce_coalesced").bagua_backend_tensor(),
            int(op),
        )
        # Copy the reduced values back into the caller's tensors.
        for buf, synced in zip(tensors, unflatten(coalesced, tensors)):
            buf.copy_(synced)

    # TODO: remove
    comm.cuda_stream.synchronize()
def allreduce(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    op: ReduceOp = ReduceOp.SUM,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """Reduce the tensor data across all processes associated with the communicator
    so that every process gets the final result. After the call
    :attr:`recv_tensor` is going to be bitwise identical in all processes.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor: Input of the collective; must not live on the CPU.
        recv_tensor: Output of the collective, must have the same size with
            :attr:`send_tensor`; must not live on the CPU.
        op: One of the values from :class:`ReduceOp` enum. Specifies an operation used for element-wise reductions.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.

    Examples::

        >>> from bagua.torch_api import allreduce
        >>>
        >>> # All tensors below are of torch.int64 type and live on a CUDA
        >>> # device (the device suffix is omitted from the printed output).
        >>> # We have 2 ranks.
        >>> send_tensor = (torch.arange(2, dtype=torch.int64) + 1 + 2 * rank).cuda()
        >>> recv_tensor = torch.zeros(2, dtype=torch.int64).cuda()
        >>> send_tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> allreduce(send_tensor, recv_tensor)
        >>> recv_tensor
        tensor([4, 6]) # Rank 0
        tensor([4, 6]) # Rank 1

        >>> # All tensors below are of torch.cfloat type.
        >>> # We have 2 ranks.
        >>> send_tensor = (torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)).cuda()
        >>> recv_tensor = torch.zeros(2, dtype=torch.cfloat).cuda()
        >>> send_tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> allreduce(send_tensor, recv_tensor)
        >>> recv_tensor
        tensor([4.+4.j, 6.+6.j]) # Rank 0
        tensor([4.+4.j, 6.+6.j]) # Rank 1
    """
    if _rank_not_in_comm(comm):
        return

    # Unlike the sibling collectives, the communicator is resolved before the
    # device checks here; the order is kept as-is.
    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.allreduce(source, target, int(op))

    # TODO: remove
    comm_stream.synchronize()
def allreduce_inplace(
    tensor: torch.Tensor,
    op: ReduceOp = ReduceOp.SUM,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """The in-place version of :func:`allreduce`.

    ``tensor`` serves as both input and output and must not live on the CPU.
    Does nothing when the current process is not a member of ``comm``.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.allreduce_inplace(backend_tensor, int(op))

    # Block until the reduction has completed.
    comm_stream.synchronize()
def allgather(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """Gather send tensors from all processes associated with the communicator
    into :attr:`recv_tensor`.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor (torch.Tensor): Input of the collective; must not live on the CPU.
        recv_tensor (torch.Tensor): Output of the collective, must have a size of
            ``comm.nranks * send_tensor.size()`` elements; must not live on the CPU.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.allgather(source, target)

    # Block until the gather has completed.
    comm_stream.synchronize()
def allgather_inplace(
    tensor: torch.Tensor,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """The in-place version of :func:`allgather`.

    ``tensor`` serves as both input and output and must not live on the CPU.
    Does nothing when the current process is not a member of ``comm``.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.allgather_inplace(backend_tensor)

    # Block until the gather has completed.
    comm_stream.synchronize()
def gather(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    dst: int,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """Gather send tensors from all processes associated with the communicator
    to :attr:`recv_tensor` in a single process.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor: Input of the collective; must not live on the CPU.
        recv_tensor: Output of the collective, must have a size of
            ``comm.nranks * send_tensor.size()`` elements; must not live on the CPU.
        dst: Destination rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.gather(source, target, dst)

    # Block until the gather has completed.
    comm_stream.synchronize()
def gather_inplace(
    tensor: torch.Tensor,
    count: int,
    dst: int,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """The in-place version of :func:`gather`.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        tensor: Input and output of the collective; must not live on the CPU.
            On the :attr:`dst` rank, it must have a size of
            ``comm.nranks * count`` elements. On non-dst ranks, its size must
            be equal to :attr:`count`.
        count: The per-rank data count to gather.
        dst: Destination rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.gather_inplace(backend_tensor, count, dst)

    # Block until the gather has completed.
    comm_stream.synchronize()
def scatter(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    src: int,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """Scatter the send tensor to all processes associated with the communicator.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor: Input of the collective, must have a size of
            ``comm.nranks * recv_tensor.size()`` elements; must not live on the CPU.
        recv_tensor: Output of the collective; must not live on the CPU.
        src: Source rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.scatter(source, target, src)

    # Block until the scatter has completed.
    comm_stream.synchronize()
def scatter_inplace(
    tensor: torch.Tensor,
    count: int,
    src: int,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """The in-place version of :func:`scatter`.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        tensor: Input and output of the collective; must not live on the CPU.
            On the :attr:`src` rank, it must have a size of
            ``comm.nranks * count`` elements. On non-src ranks, its size must
            be equal to :attr:`count`.
        count: The per-rank data count to scatter.
        src: Source rank.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert tensor.device != torch.device("cpu"), "input tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        backend_tensor = tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.scatter_inplace(backend_tensor, count, src)

    # Block until the scatter has completed.
    comm_stream.synchronize()
def reduce_scatter(
    send_tensor: torch.Tensor,
    recv_tensor: torch.Tensor,
    op: ReduceOp = ReduceOp.SUM,
    comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
    """Reduce, then scatter :attr:`send_tensor` to all processes associated with
    the communicator.

    Does nothing when the current process is not a member of ``comm``.

    Args:
        send_tensor (torch.Tensor): Input of the collective, must have a size of
            ``comm.nranks * recv_tensor.size()`` elements; must not live on the CPU.
        recv_tensor (torch.Tensor): Output of the collective; must not live on the CPU.
        op (ReduceOp, optional): One of the values from :class:`ReduceOp` enum.
            Specifies an operation used for element-wise reductions.
        comm: A handle of the Bagua communicator to work on. By default, the global
            communicator of the default process group will be used.
    """
    if _rank_not_in_comm(comm):
        return

    assert send_tensor.device != torch.device(
        "cpu"
    ), "send tensor must be CUDA and dense"
    assert recv_tensor.device != torch.device(
        "cpu"
    ), "recv tensor must be CUDA and dense"

    if comm is None or comm is CommMember.WORLD:
        comm = _get_default_group().get_global_communicator()

    comm_stream = comm.cuda_stream
    # The communication stream must wait for work already queued on the
    # caller's current stream.
    comm_stream.wait_event(torch.cuda.current_stream().record_event())

    with torch.cuda.stream(comm_stream):
        source = send_tensor.to_bagua_tensor().bagua_backend_tensor()
        target = recv_tensor.to_bagua_tensor().bagua_backend_tensor()
        comm.reduce_scatter(source, target, int(op))

    # Block until the collective has completed.
    comm_stream.synchronize()
def reduce_scatter_inplace(
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
"""The in-place version of :func:`reduce_scatter`.
Args:
tensor (torch.Tensor): Input and output of the collective, the size must be divisible by ``comm.nranks``.
op (ReduceOp, optional): One of the values from :class:`ReduceOp` enum. Specifies an operation used for element-wise reductions.
comm: A handle of the Bagua communicator to work on. By default, the global
communicator of the default process group will be used.