# pylint: disable=protected-access
"""The device mesh runtime that manages buffers and runs computation
distributedly.
The hierarchy of classes defined in this file:
DeviceCluster (the whole ray cluster)
|
PhysicalDeviceMeshGroup (multiple device meshes)
|
PhysicalDeviceMesh (one device mesh)
|
MeshHostWorker (one host in a device mesh)
Besides, we have two additional classes: VirtualPhysicalMesh and
LogicalDeviceMesh. They are only used during compilation time. They are used to
manipulate meshes flexibly without allocating real resources during compilation
time.
"""
from abc import ABC, abstractmethod
import asyncio
from collections import defaultdict, namedtuple
from collections.abc import Iterable
import logging
from operator import attrgetter
import os
import pickle
import shutil
import threading
import time
from typing import Any, List, Union, Sequence, Tuple, Optional
from jax import core, xla, device_put
from jax._src.api import ShapeDtypeStruct
from jax._src.lib import xla_bridge as xb, xla_extension as xe
from jax._src.tree_util import tree_leaves
from jax.abstract_arrays import array_types
from jax.core import ShapedArray
from jax.interpreters import pxla
from jax.interpreters.pxla import (ShardingSpec, _hashable_index,
ShardedDeviceArray, Index)
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np
import ray
from ray.util.placement_group import remove_placement_group
from alpa import mesh_profiling
import alpa.collective as col
from alpa.global_env import global_config
from alpa.monkey_patch import set_override_backend
from alpa.shard_parallel.auto_sharding import (LogicalDeviceMesh)
from alpa.parallel_plan import PlacementSpec
from alpa.timer import timers, tracer
from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
update_jax_platform, is_ray_node_resource,
try_import_ray_worker, create_placement_group,
get_bundle_idx, retrieve_placement_group, get_bundle2ip,
check_server_port)
ray_worker = try_import_ray_worker()
if global_config.backend == "gpu" and global_config.has_cuda:
from alpa.collective import worker_nccl_util
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ReshardingTileSpec = namedtuple("ReshardingTileSpec",
["offset", "rank", "gpu_idx"])
ReshardingSendSpec = namedtuple("ReshardingSendSpec",
["device_id", "tile_spec"])
ReshardingSendTask = namedtuple("ReshardingSendTask",
["tile_specs", "group_name"])
ReshardingRecvSpec = namedtuple("ReshardingRecvSpec",
["device_id", "shape", "dtype", "tile_specs"])
ReshardingRecvTask = namedtuple("ReshardingRecvTask",
["recv_specs", "group_name"])
ReshardingBroadcastSpec = namedtuple("ReshardingBroadcastSpec", [
"comm_key", "world_size", "devices_ids", "devices_global_rank",
"tensor_slices", "recv_tile_shape", "dtype"
])
ReshardingBroadcastTask = namedtuple("ReshardingBroadcastTask",
["broadcast_specs", "group_name"])
########################################
# Ray Workers
########################################
class DaemonMoveWorker:
"""
A ray actor that moves local checkpoints into the shared
filesystem in the background.
"""
def move(self, from_dir: str, to_dir: str):
os.makedirs(to_dir, exist_ok=True)
for file in os.listdir(from_dir):
from_path = os.path.join(from_dir, file)
to_path = os.path.join(to_dir, file)
shutil.move(from_path, to_path)
def sync(self):
"""Noop function used to synchronize."""
class MeshHostWorker:
"""
A ray actor that manages the xla computation and buffers on a single host.
"""
def __init__(self, server_address: str, num_hosts: int, host_id: int,
mesh_id: int, move_worker: DaemonMoveWorker,
runtime_random_seed: int, worker_global_config: dict):
self.num_hosts = num_hosts
self.host_id = host_id
self.mesh_id = mesh_id
self.move_worker = move_worker
self.distributed_client = (
xla_client._xla.get_distributed_runtime_client(
server_address, host_id, use_coordination_service=False))
logger.debug(
f"{host_id}: Trying to connect to xla runtime at {server_address}")
self.distributed_client.connect()
logger.debug(
f"{host_id}: Success to connect to xla runtime at {server_address}")
# Set global config to follow the driver
global_config.update_worker_config(worker_global_config)
if global_config.backend == "gpu":
self.backend = xla_client.make_gpu_client(self.distributed_client,
node_id=host_id)
else:
raise NotImplementedError(
f"backend {global_config.backend} is not supported")
# Monkey patch the backend
set_override_backend(self.backend)
self.local_devices = self.backend.local_devices()
self.num_devices = len(self.local_devices)
if global_config.enable_overlapping:
xe.set_num_device_on_host(self.num_devices)
self.buffers = {} # Dict[uuid -> Sequence[DeviceArray]]
self.executables = {}  # Dict[uuid -> MeshWorkerExecutable]
self.send_tasks = {} # Dict[uuid -> ReshardingSendTask]
self.recv_tasks = {} # Dict[uuid -> ReshardingRecvTask]
self.broadcast_tasks = {} # Dict[uuid -> BroadcastTask]
self.broadcast_communicators = {}
self.data_loaders = {} # Dict[uuid -> MeshWorkerDataLoader]
self.data_loader_iters = {} # Dict[uuid -> iterator]
self.set_runtime_random_seed(runtime_random_seed)
if global_config.pipeline_use_signal_send_recv:
print("Use signal send recv for debugging.")
self.signal_buffers = []
for d in self.local_devices:
jax_tensor = device_put(jnp.ones((1,), dtype=jnp.int8), d)
self.signal_buffers.append(
worker_nccl_util.to_signal_buffer(jax_tensor))
##### Buffer Related Functions #####
def put_buffers(self,
uuids: Union[int, Sequence[int]],
datas: Sequence[np.ndarray],
num_batch=1,
batch_dim=0):
assert len(datas) == self.num_devices
if not isinstance(uuids, Iterable):
uuids = [uuids]
assert len(uuids) == num_batch
if num_batch > 1:
split_datas = []
for data in datas:
split_buffers = np.split(data, num_batch, batch_dim)
split_datas.extend(split_buffers)
datas = split_datas
arys = [([None] * self.num_devices) for _ in range(num_batch)]
for i, data in enumerate(datas):
if data.dtype == np.int64:
data = data.astype(np.int32)
device_id, batch_id = divmod(i, num_batch)
arys[batch_id][device_id] = (self.backend.buffer_from_pyval(
data, self.local_devices[device_id]))
for uuid, ary in zip(uuids, arys):
self.buffers[uuid] = ary
def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int],
shape: Sequence[int], dtype: np.dtype,
indices: Sequence, num_batch: int):
if isinstance(uuids, int):
uuids = [uuids]
assert len(uuids) == num_batch
assert len(indices) == self.num_devices * num_batch
arys = [([None] * self.num_devices) for _ in range(num_batch)]
for device_id in range(self.num_devices):
for b in range(num_batch):
shard_shape = []
idx = device_id * num_batch + b
for j, s in enumerate(indices[idx]):
filled_slice = s.indices(shape[j])
dim_size = len(range(*filled_slice))
shard_shape.append(dim_size)
arys[b][device_id] = (self.backend.buffer_from_pyval(
np.full(shard_shape, 1e-8, dtype),
self.local_devices[device_id]))
for uuid, ary in zip(uuids, arys):
self.buffers[uuid] = ary
def _get_buffers_with_local_ids(self, uuid: int, device_ids: Sequence[int]):
bufs = self.buffers[uuid]
# TODO(yonghao): sync communication events. Currently it's safe because
# we never get values immediately after a cross-mesh communication.
if device_ids is None:
return map(np.asarray, bufs)
elif not isinstance(device_ids, Iterable):
return np.asarray(bufs[device_ids])
return [np.asarray(bufs[device_id]) for device_id in device_ids]
def get_buffers(self,
uuids: Union[Sequence[int], int],
device_indices: Sequence[int] = None):
if not isinstance(uuids, Iterable):
return self._get_buffers_with_local_ids(uuids, device_indices)
if device_indices is not None:
assert len(uuids) == len(device_indices)
else:
device_indices = [None] * len(uuids)
return [
self._get_buffers_with_local_ids(uuid, local_ids)
for uuid, local_ids in zip(uuids, device_indices)
]
def delete_buffers(self, uuids: Union[Sequence[int], int]):
if isinstance(uuids, Iterable):
for uuid in uuids:
del self.buffers[uuid]
else:
del self.buffers[uuids]
def block_until_ready_buffers(self, uuids: Union[Sequence[int], int]):
# We have to block on all buffers because the last operation may be
# a cross-mesh resharding (which is not SPMD).
if isinstance(uuids, Iterable):
for uuid in uuids:
for buf in self.buffers[uuid]:
buf.block_until_ready()
else:
for buf in self.buffers[uuids]:
buf.block_until_ready()
def get_memory_allocated(self):
self.sync()
return max(d.memory_allocated() for d in self.local_devices)
def get_max_memory_allocated(self):
self.sync()
return max(d.max_memory_allocated() for d in self.local_devices)
def get_available_memory(self):
self.sync()
return min(d.available_memory() for d in self.local_devices)
def reset_memory_stats(self):
self.sync()
for device in self.local_devices:
device.clear_memory_stats()
##### Executable Related Functions #####
def put_executable(self, uuid: int,
executable_class: "MeshWorkerExecutable", *args):
self.executables[uuid] = executable_class(self, uuid, *args)
def delete_executable(self, uuid: int):
if uuid in self.executables:
del self.executables[uuid]
def run_executable(self, uuid: int, *args, **kwargs):
self.executables[uuid].execute_on_worker(*args, **kwargs)
def get_exec_hlo_text(self, uuid: int):
return self.executables[uuid].get_hlo_text()
def get_exec_total_allocation_size(self, uuid: int):
return self.executables[uuid].get_total_allocation_size()
def get_exec_grad_sync_channel_ids(self, uuid: int):
return self.executables[uuid].grad_sync_channel_ids
def set_runtime_random_seed(self, seed: int):
# Offset the seed by the mesh id so that different meshes use
# different random streams.
seed = seed + (self.mesh_id << 20 if self.mesh_id else 0)
for d in self.local_devices:
d.set_seed(seed)
##### Serialization Related Functions #####
def sync_move_worker(self):
ray.get(self.move_worker.sync.remote())
def save_array(self, ckpt_dir: str, local_cache_dir: Union[str, None],
uuid: int, device_ids: Sequence[int],
shard_indices: Sequence[Index], global_shape: Sequence[int]):
assert uuid in self.buffers
array_buffers = self.buffers[uuid]
shard_names = [
f"shard_{self.host_id}.{i}" for i in range(len(device_ids))
]
metadata = {
"global_shape": global_shape,
"dtype": self.buffers[uuid][0].dtype,
"shard_names": shard_names,
"shard_indices": shard_indices,
}
# Create the directories if they do not exist.
os.makedirs(ckpt_dir, exist_ok=True)
if local_cache_dir is not None:
os.makedirs(local_cache_dir, exist_ok=True)
save_dir = local_cache_dir
else:
save_dir = ckpt_dir
for shard_name, device_id in zip(shard_names, device_ids):
with open(os.path.join(save_dir, shard_name), "wb") as datafile:
np.save(datafile, array_buffers[device_id])
with open(os.path.join(save_dir, f"metadata_{self.host_id}"),
"wb") as metafile:
pickle.dump(metadata, metafile)
# move data
if local_cache_dir is not None:
self.move_worker.move.remote(local_cache_dir, ckpt_dir)
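# Resulting on-disk layout of save_array above for a 2-host mesh with 4
# devices per host (illustrative):
#
#   ckpt_dir/
#     shard_0.0 ... shard_0.3, metadata_0   # written by host 0
#     shard_1.0 ... shard_1.3, metadata_1   # written by host 1
#
# Each pickled metadata_<host_id> records the global shape, dtype, shard file
# names, and shard indices that load_array below needs to reassemble the
# array.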
def load_array(self, ckpt_dir: str, uuid: Sequence[int],
device_ids: Sequence[int], shard_indices: Sequence[Index]):
metadatas = list(
filter(lambda fname: fname.startswith("metadata"),
os.listdir(ckpt_dir)))
# pylint: disable=import-outside-toplevel
from alpa.serialization import load_sharded_array
entire_arr = load_sharded_array(ckpt_dir, metadatas)
array_buffers = [None] * self.num_devices
for index, device_id in zip(shard_indices, device_ids):
data = entire_arr[index]
if data.dtype == np.int64:
data = data.astype(np.int32)
array_buffers[device_id] = (self.backend.buffer_from_pyval(
data, self.local_devices[device_id]))
self.buffers[uuid] = array_buffers
##### Data loader Related Functions #####
def put_data_loader(self, uuid: int, *args):
# pylint: disable=import-outside-toplevel
from alpa.data_loader import MeshWorkerDataLoader
self.data_loaders[uuid] = MeshWorkerDataLoader(self, *args)
def data_loader_iter(self, uuid: int):
self.data_loader_iters[uuid] = iter(self.data_loaders[uuid])
def data_loader_next(self, uuid: int):
next(self.data_loader_iters[uuid])
def delete_data_loader(self, uuid: int):
del self.data_loaders[uuid]
##### Cross Mesh Resharding Related Functions #####
@staticmethod
def init_collective_group(world_size, rank, backend, group_name):
"""Initialize the collective group eagerly."""
col.init_collective_group(world_size,
rank,
backend=backend,
group_name=group_name)
@staticmethod
def generate_nccl_uid(group_name):
"""Generate the NCCL unique ID in advance."""
g = col.check_and_get_group(group_name)
uid = g.generate_nccl_uid()
return uid
@staticmethod
def init_p2p_communicator(group_name, my_rank, my_gpu_idx, peer_rank,
peer_gpu_idx, nccl_uid):
"""Initialize the P2P communicator from within the mesh workers."""
assert col.is_group_initialized(group_name)
assert col.get_rank(group_name) == my_rank
g = col.check_and_get_group(group_name)
g.create_p2p_communicator(my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid)
@staticmethod
def init_broadcast_communicator(group_name, comm_key, world_size,
device_ids, devices_global_rank, nccl_uid):
"""Initialize the P2P communicator from within the mesh workers."""
assert col.is_group_initialized(group_name)
g = col.check_and_get_group(group_name)
g.create_nccl_broadcast_communicator(comm_key, world_size, device_ids,
devices_global_rank, nccl_uid)
@staticmethod
def destroy_collective_group(group_name: str = "default"):
col.destroy_collective_group(group_name)
def create_and_set_cross_mesh_communicators(self, world_size, rank, backend,
group_name, key):
"""Create collective communicators for the cross mesh group."""
if not col.is_group_initialized(group_name):
self.init_collective_group(world_size, rank, backend, group_name)
g = col.check_and_get_group(group_name)
devices = list(range(self.num_devices))
g.create_and_set_xla_communicators(devices, key)
def put_resharding_send_task(self, uuid, tasks, group_name):
self.send_tasks[uuid] = ReshardingSendTask(tile_specs=tasks,
group_name=group_name)
def put_resharding_recv_task(self, uuid, tasks, group_name):
self.recv_tasks[uuid] = ReshardingRecvTask(recv_specs=tasks,
group_name=group_name)
def run_resharding_send_task(self, uuid, ary_uuid):
task: ReshardingSendTask = self.send_tasks[uuid]
group_name = task.group_name
if global_config.enable_overlapping:
col.wait_events(group_name, [ary_uuid], self.num_devices, True)
for send_tile_spec in task.tile_specs:
send_tile_spec: ReshardingSendSpec
self.send_tile(ary_uuid, send_tile_spec.device_id,
send_tile_spec.tile_spec.offset,
send_tile_spec.tile_spec.rank,
send_tile_spec.tile_spec.gpu_idx, task.group_name)
def run_resharding_recv_task(self, uuid, ary_uuid, set_empty_buffer=True):
task: ReshardingRecvTask = self.recv_tasks[uuid]
group_name = task.group_name
if set_empty_buffer and ary_uuid not in self.buffers:
assert not global_config.enable_overlapping, "Unsupported."
self.buffers[ary_uuid] = [None] * self.num_devices
if global_config.enable_overlapping:
col.wait_events(group_name, [ary_uuid], self.num_devices, False)
buffers = self.buffers[ary_uuid]
for recv_spec in task.recv_specs:
recv_spec: ReshardingRecvSpec
device_id = recv_spec.device_id
if set_empty_buffer:
buffers[device_id] = self.backend.buffer_from_pyval(
np.full(recv_spec.shape, 1e-8, recv_spec.dtype),
self.local_devices[device_id])
for recv_tile_spec in recv_spec.tile_specs:
recv_tile_spec: ReshardingTileSpec
self.recv_tile(ary_uuid, device_id, recv_tile_spec.offset,
recv_tile_spec.rank, recv_tile_spec.gpu_idx,
task.group_name)
if global_config.enable_overlapping:
col.record_events(group_name, [ary_uuid], self.num_devices, False)
def send_tile(self, uuid: int, device_id: int, offset: Sequence[slice],
dst_rank: int, dst_gpu_idx: int, group_name: str):
if global_config.pipeline_use_signal_send_recv:
signal = self.signal_buffers[device_id]
col.send_multigpu(signal,
dst_rank,
dst_gpu_idx,
group_name,
start_pos=0,
n_elements=1)
else:
worker_nccl_util.send_tile(self, uuid, device_id, offset, dst_rank,
dst_gpu_idx, group_name)
def recv_tile(self, uuid: int, device_id: int,
indices_in_dst_tile: Sequence[slice], src_rank: int,
src_gpu_idx: int, group_name: str):
if uuid not in self.buffers:
raise RuntimeError("Buffer has not been created.")
if global_config.pipeline_use_signal_send_recv:
signal = self.signal_buffers[device_id]
col.recv_multigpu(signal,
src_rank,
src_gpu_idx,
group_name,
start_pos=0,
n_elements=1)
else:
worker_nccl_util.recv_tile(self, uuid, device_id,
indices_in_dst_tile, src_rank,
src_gpu_idx, group_name)
def put_resharding_broadcast_task(self, uuid, tasks, group_name):
self.broadcast_tasks[uuid] = ReshardingBroadcastTask(
broadcast_specs=tasks, group_name=group_name)
def run_resharding_broadcast_task(self,
uuid,
ary_uuid,
set_empty_buffer=True):
task: ReshardingBroadcastTask = self.broadcast_tasks[uuid]
group_name = task.group_name
broadcast_specs = task.broadcast_specs
if set_empty_buffer and ary_uuid not in self.buffers:
assert not global_config.enable_overlapping, "Unsupported."
picked_spec = list(broadcast_specs.values())[0]
shape = picked_spec.recv_tile_shape
dtype = picked_spec.dtype
self.buffers[ary_uuid] = [
self.backend.buffer_from_pyval(np.full(shape, 1e-8, dtype),
self.local_devices[device_id])
for device_id in range(self.num_devices)
]
has_recv = False
for group_idx in broadcast_specs:
broadcast_spec: ReshardingBroadcastSpec = broadcast_specs[group_idx]
is_send = broadcast_spec.devices_global_rank[0] == 0
has_recv = has_recv or not is_send
if global_config.enable_overlapping:
col.wait_events(group_name, [ary_uuid], self.num_devices,
is_send)
worker_nccl_util.broadcast(self, ary_uuid, broadcast_spec.comm_key,
broadcast_spec.world_size,
broadcast_spec.devices_ids,
broadcast_spec.devices_global_rank,
broadcast_spec.tensor_slices,
task.group_name)
if global_config.enable_overlapping and has_recv:
col.record_events(group_name, [ary_uuid], self.num_devices, False)
##### Profiling and Debugging Related Functions #####
def profile_hlo_ops(self, op_infos: Sequence[Any], cache_filename: str,
single_timeout: float):
num_devices = self.num_hosts * len(self.local_devices)
return mesh_profiling.profile_hlo_ops(op_infos, self.backend,
self.local_devices, self.host_id,
num_devices, cache_filename,
single_timeout)
def profile_executable_with_dummy_inputs(self, uuid: int, **kwargs):
return self.executables[uuid].profile_with_dummy_inputs(
self.backend, self.local_devices, **kwargs)
def profile_resharding_send_task(self,
uuid,
buf_uuids,
warmup=1,
repeat=3,
number=3,
sync=False):
# TODO(yonghao): the sync function should be carefully reconsidered
def run_fn():
self.run_resharding_send_task(uuid, buf_uuids)
sync_fn = self.sync if sync else None
costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number)
return np.mean(costs)
def profile_resharding_recv_task(self,
uuid,
buf_uuids,
warmup=1,
repeat=3,
number=3,
sync=False):
set_empty_buffer = True
def run_fn():
nonlocal set_empty_buffer
self.run_resharding_recv_task(uuid, buf_uuids, set_empty_buffer)
set_empty_buffer = False
sync_fn = self.sync if sync else None
costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number)
return np.mean(costs)
@staticmethod
def get_timer(name: str):
return timers(name)
@staticmethod
def reset_timer(name: str):
timers(name).reset()
@staticmethod
def get_tracer():
return tracer
def get_live_buffer_uuids(self):
return list(self.buffers.keys())
##### Other Functions #####
def sync(self, sync_all_devices=False):
# We sync one device instead of all for smaller runtime overhead.
# This is correct because of SPMD.
if sync_all_devices:
for device in self.local_devices:
device.synchronize_all_activity()
else:
self.local_devices[0].synchronize_all_activity()
def sync_all(self):
for device in self.local_devices:
device.synchronize_all_activity()
@staticmethod
def check_alive():
return True
def shutdown(self):
self.sync()
self.buffers.clear()
self.executables.clear()
self.distributed_client.shutdown()
# sync & shutdown DaemonMoveWorker
self.sync_move_worker()
ray.kill(self.move_worker)
self.move_worker = None
########################################
# Device Meshes
########################################
class PhysicalDeviceMesh(ABC):
"""The base class of physical device mesh.
A physical device mesh is a 2-dimensional mesh that runs SPMD computation on
all devices in the mesh.
"""
num_hosts: int
num_devices_per_host: int
mesh_id: int
operation_executables: dict
one_replica_ids: dict
def get_signature(self) -> str:
"""Return a signature string that contains the mesh shape and GPU
model."""
gpu_type = list_gpu_info()
gpu_name = gpu_type.split("\n")[0].split(" (UUID:")[0][7:]
ret = f"{self.num_hosts},{self.num_devices_per_host},{gpu_name}"
ret = ret.replace(" ", "-")
return ret
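# Illustrative example for get_signature above, assuming list_gpu_info()
# returns nvidia-smi style lines such as
#   "GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-...)".
# A 2-host x 8-device mesh would then produce "2,8,NVIDIA-A100-SXM4-40GB".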
def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec):
# The tuple (aval_shape, sharding_spec) maps one-to-one to `indices`,
# which is used to compute one_replica_ids.
if (aval_shape, sharding_spec) in self.one_replica_ids:
return self.one_replica_ids[(aval_shape, sharding_spec)]
one_replica_indices = []
one_replica_host_local_ids = []
seen_index_hashes = set()
for i, index in enumerate(indices):
hashed_index = _hashable_index(index)
if hashed_index not in seen_index_hashes:
one_replica_indices.append(i)
one_replica_host_local_ids.append(
divmod(i, self.num_devices_per_host))
seen_index_hashes.add(hashed_index)
self.one_replica_ids[(
aval_shape,
sharding_spec)] = one_replica_indices, one_replica_host_local_ids
return one_replica_indices, one_replica_host_local_ids
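# Example for _compute_one_replica_ids above: for a fully replicated array on
# a 1-host x 4-device mesh, all four entries of `indices` hash identically, so
# only the first device is kept: one_replica_indices == [0] and
# one_replica_host_local_ids == [(0, 0)].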
@property
def shape(self):
return self.num_hosts, self.num_devices_per_host
@property
def num_devices(self):
"""Return the total number of GPUs on this mesh."""
return self.num_hosts * self.num_devices_per_host
##### Logical Mesh Related Functions #####
def get_logical_mesh(self,
mesh_shape: Optional[Sequence[int]] = None,
mesh_alpha: Optional[float] = None,
mesh_beta: Optional[float] = None,
mesh_topology: Optional[str] = None,
intra_host_bandwidth: Optional[float] = None,
inter_host_bandwidth: Optional[float] = None):
"""
Return a logical mesh and parameters of the alpha-beta communication
cost model. The logical view is used for auto-sharding.
"""
if mesh_shape is None:
mesh_shape = (self.num_hosts, self.num_devices_per_host)
id_mesh = np.arange(self.num_devices).reshape(mesh_shape)
if mesh_topology is None:
# Use the provided mesh_alpha and mesh_beta
mesh_alpha = mesh_alpha or (1, 1)
mesh_beta = mesh_beta or (1, 0.1)
elif mesh_topology == "tree":
# Derive mesh_alpha and mesh_beta from topology,
# intra_host_bandwidth and inter_host_bandwidth
assert mesh_alpha is None
assert mesh_beta is None
mesh_alpha = [1] * 2
mesh_beta = [None] * 2
host_ids = np.tile(
np.arange(self.num_hosts).reshape(-1, 1),
self.num_devices_per_host)
host_ids = host_ids.reshape(mesh_shape)
# Compute the bandwidth of communication along dim 0.
# 1. Count the number of links between each pair of hosts,
# assuming ring-based algorithms are used.
host_link_ct = defaultdict(int)
for j in range(mesh_shape[1]):
for i in range(mesh_shape[0]):
left = host_ids[i][j]
right = host_ids[(i + 1) % mesh_shape[0]][j]
if left != right:
if left > right:
left, right = right, left
host_link_ct[(left, right)] += 1
j = 0
# 2. Bandwidth between two hosts
# = total_bandwidth / number_of_links.
# Bandwidth along a communication dimension
# = min bandwidth of all links.
bandwidth = intra_host_bandwidth
for i in range(mesh_shape[0]):
left = host_ids[i][j]
right = host_ids[(i + 1) % mesh_shape[0]][j]
if left != right:
if left > right:
left, right = right, left
bandwidth = min(
bandwidth,
inter_host_bandwidth / host_link_ct[(left, right)])
mesh_beta[0] = 1 / bandwidth
# Compute the bandwidth of communication along dim 1.
host_link_ct = defaultdict(int)
for i in range(mesh_shape[0]):
for j in range(mesh_shape[1]):
left = host_ids[i][j]
right = host_ids[i][(j + 1) % mesh_shape[1]]
if left != right:
if left > right:
left, right = right, left
host_link_ct[(left, right)] += 1
i = 0
bandwidth = intra_host_bandwidth
for j in range(mesh_shape[1]):
left = host_ids[i][j]
right = host_ids[i][(j + 1) % mesh_shape[1]]
if left != right:
if left > right:
left, right = right, left
bandwidth = min(
bandwidth,
inter_host_bandwidth / host_link_ct[(left, right)])
mesh_beta[1] = 1 / bandwidth
return LogicalDeviceMesh(self, id_mesh, mesh_alpha, mesh_beta)
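# Worked example for the "tree" topology branch above (bandwidth numbers are
# illustrative): a (2 hosts x 2 devices/host) mesh gives
# host_ids == [[0, 0], [1, 1]]. Along dim 0 every ring neighbor pair crosses
# hosts, so host_link_ct[(0, 1)] == 4 and
# mesh_beta[0] == 1 / min(intra_host_bandwidth, inter_host_bandwidth / 4);
# along dim 1 all neighbors stay on the same host, so
# mesh_beta[1] == 1 / intra_host_bandwidth. With intra_host_bandwidth = 100
# and inter_host_bandwidth = 10 (GB/s), mesh_beta == [0.4, 0.01].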
##### Executable Related Functions #####
@abstractmethod
def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
donated_invars: Sequence[bool],
batch_invars: Sequence[bool], num_micro_batches: int,
args: Sequence[Any]):
"""Shard high-level arguments as low-level buffers."""
raise NotImplementedError()
@abstractmethod
def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
shard_indices: Sequence[Sequence[Index]],
sharding_specs: Sequence[ShardingSpec],
args: Sequence[Any]):
"""Shard arguments (np.ndarray) as distributed arrays."""
raise NotImplementedError()
def shard_args_to_arrays_ps(self, placement_specs: PlacementSpec,
args: Sequence[Any]):
"""
Shard arguments (np.ndarray) as distributed arrays according to
PlacementSpec.
"""
avals = tuple(x.aval for x in placement_specs)
assert all(
len(x.mesh_ids) == 1 and x.mesh_ids[0] == self.mesh_id
for x in placement_specs)
specs = tuple(x.sharding_specs[0] for x in placement_specs)
indices = tuple(
pxla.spec_to_indices(aval.shape, spec)
for aval, spec in zip(avals, specs))
return self.shard_args_to_arrays(avals, indices, specs, args)
@abstractmethod
def get_outputs_handler(self, avals: Sequence[ShapedArray],
sharding_specs: Sequence[ShardingSpec]):
"""
Get a function that wraps low-level buffers to high-level output arrays.
"""
raise NotImplementedError()
@abstractmethod
def set_runtime_random_seed(self, seed: int):
raise NotImplementedError()
##### Profiling Related Functions #####
@abstractmethod
def get_remote_timer(self, timer_name: str):
raise NotImplementedError()
@abstractmethod
def reset_remote_timer(self, timer_name: str):
raise NotImplementedError()
@abstractmethod
def get_remote_tracer(self):
raise NotImplementedError()
@abstractmethod
def get_memory_allocated(self):
raise NotImplementedError()
@abstractmethod
def get_max_memory_allocated(self):
raise NotImplementedError()
@abstractmethod
def get_available_memory(self):
raise NotImplementedError()
@abstractmethod
def reset_memory_stats(self):
raise NotImplementedError()
##### Other Functions #####
@abstractmethod
def sync_workers(self):
"""Sync device activities on all workers."""
raise NotImplementedError()
@abstractmethod
def shutdown(self, forced=False):
"""Shut down the mesh."""
raise NotImplementedError()
class LocalPhysicalDeviceMesh(PhysicalDeviceMesh):
"""
A single-host physical device mesh to run computation on local devices.
It uses the native XLA runtime.
"""
def __init__(self, devices: Sequence["Device"] = None):
self.devices = devices if devices is not None else xb.local_devices()
self.num_hosts = 1
self.num_devices_per_host = len(self.devices)
self.mesh_id = -1
self.device_strs = []
self.operation_executables = {}
self.one_replica_ids = {}
self.backend = xb.get_backend(global_config.backend)
self.set_runtime_random_seed(global_config.runtime_random_seed)
##### Executable Related Functions #####
def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
donated_invars: Sequence[bool],
batch_invars: Sequence[bool], num_micro_batches: int,
args: Sequence[Any]):
bufs = []
for arg, indices, donated, is_batch_var in zip(args, shard_indices,
donated_invars,
batch_invars):
if is_batch_var:
micro_batches = jnp.split(arg, num_micro_batches)
bufs.append([
pxla._shard_arg(x, self.devices, indices, None)
for x in micro_batches
])
else:
if (isinstance(arg, pxla.ShardedDeviceArray) and
arg.indices == indices):
bufs.append(arg.device_buffers)
else:
bufs.append(
pxla._shard_arg(arg, self.devices, indices, None))
if isinstance(arg, xe.DeviceArray) and donated:
arg.delete()
return bufs
def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
shard_indices: Sequence[Sequence[Index]],
sharding_specs: Sequence[ShardingSpec],
args: Sequence[Any]):
arrays = []
for i in range(len(avals)):
if global_config.use_dummy_value_for_benchmarking:
args[i] = np.full(avals[i].shape, 1e-8, avals[i].dtype)
shards = [
args[i][shard_indices[i][k]] for k in range(len(self.devices))
]
buffers = [device_put(x, d) for x, d in zip(shards, self.devices)]
arrays.append(
pxla._ShardedDeviceArray(avals[i], sharding_specs[i], buffers,
shard_indices[i]))
return arrays
def get_outputs_handler(self, avals: Sequence[ShapedArray],
sharding_specs: Sequence[ShardingSpec]):
pmap_specs = pxla._get_pmap_sharding(np.arange(self.num_devices),
sharding_specs)
outs_handler = pxla.local_avals_to_results_handler(avals, pmap_specs)
return outs_handler
def set_runtime_random_seed(self, seed: int):
for d in self.devices:
if d is not None:
d.set_seed(seed)
##### Profiling Related Functions #####
def get_remote_timer(self, timer_name: str):
return timers(timer_name)
def reset_remote_timer(self, timer_name: str):
timers(timer_name).reset()
def get_remote_tracer(self):
return tracer
def get_memory_allocated(self):
return max(d.memory_allocated() for d in self.devices)
def get_max_memory_allocated(self):
return max(d.max_memory_allocated() for d in self.devices)
def get_available_memory(self):
return min(device.available_memory() for device in self.devices)
def reset_memory_stats(self):
for device in self.devices:
device.clear_memory_stats()
##### Other Functions #####
def sync_workers(self):
# We sync one device instead of all for smaller runtime overhead.
# This is correct because of SPMD.
self.devices[0].synchronize_all_activity()
def shutdown(self, forced=False):
self.sync_workers()
self.operation_executables.clear()
def device_id_to_str(host_ip, device_id, device_type="gpu"):
"""Convert device id (int) to a canonical device string."""
return f"{host_ip}:{device_type}:{device_id}"
# Used ports for XLA distributed runtime servers.
used_port_set = set((None,))
class DistributedPhysicalDeviceMesh(PhysicalDeviceMesh):
"""
A multi-host physical device mesh that runs computation in a distributed way.
It uses ray actors and the distributed XLA runtime.
"""
def __init__(self,
host_ids: Sequence[int],
host_info: Sequence[dict],
num_devices_per_host: int,
parent: Optional["VirtualPhysicalMesh"] = None,
devices: Optional[Sequence[Sequence[int]]] = None,
mesh_id: Optional[int] = None,
namespace: Optional[str] = None):
# host_ids are the indices of hosts in the global DeviceCluster
self.host_ids = host_ids
self.host_info = host_info
self.num_hosts = len(host_ids)
self.num_devices_per_host = num_devices_per_host
self.parent = parent
self.mesh_id = mesh_id
self.workers = None