"""Pipeline computation definitions."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import Sequence, Any, Dict, Optional
import jax
from jax import jit
from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
from jax._src.util import partial, safe_map
from jax._src import dispatch
from jax.core import (Atom, Var, JaxprEqn, Jaxpr, ClosedJaxpr, DropVar, Literal,
jaxpr_as_fun, new_jaxpr_eqn, gensym, named_call_p,
ShapedArray, get_aval, raise_to_shaped)
from jax.interpreters import pxla
from jax.interpreters.partial_eval import remat_call_p
import numpy as np
from alpa.mesh_executable import PartialGradAccMeshDriverExecutable
from alpa.parallel_plan import StagePlan
from alpa.pipeline_parallel.primitive_def import (mark_hook_jaxpreqn,
pipeline_p,
mark_pipeline_jaxpreqn)
from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,
run_spmd_partitioner_pass,
get_input_output_sharding_specs,
hlo_sharding_to_sharding_spec,
AutoShardingOption)
from alpa.global_env import global_config
from alpa.util import (OrderedSet, clone_jaxpr, get_compile_options,
jaxpr_to_hlo_module, setup_computation_alias,
compile_dummy_zero_constant, get_var_mapping)
# pylint: disable=redefined-builtin
unsafe_map, map = map, safe_map # type: ignore
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class PipelineComputation(ABC):
"""
Base class of pipeline computations.
Attributes:
name (str): The name of the pipeline computation.
invars (Sequence[Var]): The list of input variables, corresponding to
the order of the runnable inputs.
outvars (Sequence[Var]): The list of output variables, corresponding to
the order of the runnable outputs.
"""
name: str
invars: Sequence[Var] = field(default_factory=list)
outvars: Sequence[Var] = field(default_factory=list)
@abstractmethod
def get_runnable(self, mesh=None):
"""Compile the computation and get the runnable."""
raise NotImplementedError()
@dataclass
class StrVarPipelineComputation:
"""Stringified computation with all Set/Dict have string keys."""
name: str
invars: Sequence[str]
outvars: Sequence[str]
@classmethod
def from_pipeline_computation(cls,
pipeline_computation: PipelineComputation):
"""Construct a StrVarPipelineComputation from a PipelineComputation."""
return cls(
name=pipeline_computation.name,
invars=[repr(var) for var in pipeline_computation.invars],
outvars=[repr(var) for var in pipeline_computation.outvars],
)
@dataclass
class JaxPipelineComputation(PipelineComputation):
"""
A pipeline computation defined by Jaxpr.
Attributes:
eqns (Sequence[JaxprEqn]): Jaxpr equations of the pipeline computation.
consts_dir: Dict[Atom, Any]: All the constants used in the pipeline
computation.
"""
eqns: Sequence[JaxprEqn] = field(default_factory=list)
consts_dir: Dict[Atom, Any] = field(default_factory=dict)
def closed_jaxpr(self) -> ClosedJaxpr:
"""
Get the closed Jaxpr of the pipeline computation.
Returns:
ClosedJaxpr: The result ClosedJaxpr.
"""
jaxpr = Jaxpr(
constvars=list(self.consts_dir.keys()),
invars=self.invars,
outvars=self.outvars,
eqns=self.eqns,
)
closed_jaxpr = ClosedJaxpr(jaxpr, list(self.consts_dir.values()))
return closed_jaxpr
def get_runnable(self, mesh=None):
"""Return a JIT callable of the pipeline computation."""
closed_jaxpr = self.closed_jaxpr()
return jit(jaxpr_as_fun(closed_jaxpr))
@classmethod
def from_closed_jaxpr(cls, name, closed_jaxpr: ClosedJaxpr):
"""Construct a JaxPipelineComputation from a Jaxpr."""
return cls(name=name,
invars=closed_jaxpr.jaxpr.invars,
outvars=closed_jaxpr.jaxpr.outvars,
eqns=closed_jaxpr.eqns,
consts_dir=dict(
zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)))
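# Editor's illustrative sketch (hypothetical, not part of the original module):
# it shows how a traced ClosedJaxpr could be wrapped into a
# JaxPipelineComputation and compiled via get_runnable(). The helper name and
# the toy function are assumptions for documentation only; nothing calls it.
def _example_jax_pipeline_computation():
    """Hypothetical demo: build and run a JaxPipelineComputation."""
    closed_jaxpr = jax.make_jaxpr(lambda x: x * 2 + 1)(np.ones(4))
    computation = JaxPipelineComputation.from_closed_jaxpr(
        "demo_stage", closed_jaxpr)
    # get_runnable() returns a jitted function over the flat jaxpr arguments.
    runnable = computation.get_runnable()
    return runnable(np.ones(4))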
@dataclass
class XlaPipelineComputation(PipelineComputation):
"""A pipeline computation defined by XLA HLO Module."""
hlo_module: xe.HloModule = None
@classmethod
def from_jax_pipeline_computation(
cls, jax_pipeline_computation: JaxPipelineComputation):
"""
        Construct an XlaPipelineComputation from a JaxPipelineComputation.
Args:
jax_pipeline_computation (JaxPipelineComputation): the source
JaxPipelineComputation.
"""
closed_jaxpr = jax_pipeline_computation.closed_jaxpr()
backend = xb.get_backend("gpu")
name = f"pipeline_computation_{jax_pipeline_computation.name}"
hlo_module = jaxpr_to_hlo_module(name, closed_jaxpr, None, backend)
return cls(
name=jax_pipeline_computation.name,
hlo_module=hlo_module,
invars=jax_pipeline_computation.invars,
outvars=jax_pipeline_computation.outvars,
)
def get_runnable(self, mesh=None):
"""Return a callable of the pipeline computation."""
out_avals = [var.aval for var in self.outvars]
tuple_args = len(
self.invars) > 100 # pass long arg lists as tuple for TPU
backend = "gpu"
backend = xb.get_backend(backend)
device = backend.get_default_device_assignment(1)[0]
options = get_compile_options(
num_replicas=1,
num_partitions=1,
device_assignment=(device.id,) if device else None,
use_spmd_partitioning=False,
parameter_is_tupled_arguments=tuple_args,
build_random_seed=global_config.compile_random_seed,
)
xla_computation = xc.XlaComputation(
self.hlo_module.as_serialized_hlo_module_proto())
compiled = backend.compile(xla_computation, compile_options=options)
result_handlers = map(partial(dispatch.aval_to_result_handler, device),
out_avals)
buffer_counts = (None if len(out_avals) == 1 else [
dispatch.aval_to_num_buffers(aval) for aval in out_avals
])
kept_var_idx = range(len(self.invars))
return partial(
dispatch._execute_compiled, # pylint: disable=protected-access
self.name,
compiled,
buffer_counts,
result_handlers,
kept_var_idx)
def get_hlo_text(self):
"""Get the HLO text."""
return self.hlo_module.to_string()
@dataclass
class XlaShardedPipelineComputation(PipelineComputation):
"""
A pipeline computation defined by XLA HLO Module.
    The XLA HLO module is annotated with sharding specs.
"""
sharding_annotated_module: xe.HloModule = None
donated_invars: Sequence[bool] = None
stage_plan: StagePlan = None
input_sharding_specs: Sequence[pxla.ShardingSpec] = None
output_sharding_specs: Sequence[pxla.ShardingSpec] = None
output_acc_grad_indices: Sequence[int] = None
donatables: OrderedSet[Var] = None
spmd_partitioned_hlo_module: xe.HloModule = None
@classmethod
def dummy_computation(cls, name, logical_mesh_shape, gensym_func):
"""Create a dummy computation."""
backend_name = "gpu"
backend = xb.get_backend(backend_name)
stage_plan = StagePlan(global_config.compile_random_seed,
logical_mesh_shape, 1, 1, AutoShardingOption(),
None, 0)
compiled = compile_dummy_zero_constant(backend,
np.prod(logical_mesh_shape))
sharding_annotated_module = compiled.hlo_modules()[0]
outvar = gensym_func(ShapedArray((), np.dtype(np.int32)))
return cls(
name=name,
sharding_annotated_module=sharding_annotated_module,
stage_plan=stage_plan,
donated_invars=[],
invars=[],
outvars=[outvar],
output_acc_grad_indices=[],
donatables=OrderedSet(),
)
@classmethod
def from_auto_sharded_computation(
cls,
*,
jax_pipeline_computation: JaxPipelineComputation,
sharding_annotated_module: xe.HloModule,
stage_plan: StagePlan,
donated_invars: Sequence[bool] = None,
acc_grad_outvars: Sequence[Var] = (),
donatables: OrderedSet[Var] = None):
"""Run auto-sharding optimizer on a Jax pipeline computation."""
if donatables is None:
donatables = OrderedSet()
if not donated_invars:
donated_invars = (False,) * len(jax_pipeline_computation.invars)
acc_grad_indices = [
out_idx
for out_idx, outvar in enumerate(jax_pipeline_computation.outvars)
if outvar in acc_grad_outvars
]
return cls(name=jax_pipeline_computation.name,
sharding_annotated_module=sharding_annotated_module,
stage_plan=stage_plan,
donated_invars=donated_invars,
invars=jax_pipeline_computation.invars,
outvars=jax_pipeline_computation.outvars,
output_acc_grad_indices=acc_grad_indices,
donatables=donatables)
def donate_intermediates(self, computation):
"""Donate intermediate variables."""
# get sharding annotated hlo module
hlo_module = computation.as_hlo_module()
donatable = OrderedSet(self.donatables)
# get sharding specs
hlo_module.infer_spmd_shardings()
avals = [var.aval for var in self.invars]
out_avals = [var.aval for var in self.outvars]
logical_mesh_shape = self.stage_plan.logical_mesh_shape
input_shardings = hlo_module.spmd_parameters_shardings()
input_sharding_specs = [
hlo_sharding_to_sharding_spec(proto_tuple, aval, logical_mesh_shape)
for (proto_tuple, aval) in zip(input_shardings, avals)
]
output_shardings = hlo_module.spmd_output_sharding()
output_sharding_specs = hlo_sharding_to_sharding_spec(
output_shardings, out_avals, logical_mesh_shape)
num_donated = np.count_nonzero(self.donated_invars)
donatable_outvars = OrderedSet(self.outvars[num_donated:])
donated_invars = []
donated_outvars = []
var_indices = dict(zip(self.outvars, range(len(self.outvars))))
var_indices.update(dict(zip(self.invars, range(len(self.invars)))))
for idx, invar in enumerate(self.invars):
if invar not in donatable:
# not donatable
continue
if self.donated_invars[idx]:
# already donated
continue
for outvar in donatable_outvars:
if (invar.aval.shape == outvar.aval.shape and
input_sharding_specs[var_indices[invar]]
== output_sharding_specs[var_indices[outvar]]):
donated_invars.append(invar)
donated_outvars.append(outvar)
donatable_outvars.discard(outvar)
break
# set alias
for invar, outvar in zip(donated_invars, donated_outvars):
invar_idx, outvar_idx = var_indices[invar], var_indices[outvar]
computation.setup_alias((outvar_idx,), invar_idx, ())
for invar in donated_invars:
self.donated_invars[var_indices[invar]] = True
def get_spmd_partitioned(self):
"""Run spmd partitioner to get the input/output sharding specs after
partitioning."""
if self.spmd_partitioned_hlo_module is not None:
return self.spmd_partitioned_hlo_module
stage_plan = self.stage_plan
logical_mesh_shape = stage_plan.logical_mesh_shape
hlo_module = self.sharding_annotated_module
setup_computation_alias(hlo_module, self.donated_invars)
num_devices = np.prod(logical_mesh_shape)
rewrite_for_grad_acc = len(self.output_acc_grad_indices) > 0
spmd_partitioned_hlo_module = run_spmd_partitioner_pass(
hlo_module,
num_devices,
rewrite_for_grad_acc=rewrite_for_grad_acc,
rewrite_grad_acc_indices=self.output_acc_grad_indices)
avals = [var.aval for var in self.invars]
out_avals = [var.aval for var in self.outvars]
input_sharding_specs, output_sharding_specs = (
get_input_output_sharding_specs(spmd_partitioned_hlo_module, avals,
out_avals, num_devices,
stage_plan.logical_mesh_shape))
self.input_sharding_specs = input_sharding_specs
self.output_sharding_specs = output_sharding_specs
        # run_spmd_partitioner_pass modifies the HLO module in place, so the
        # old module can no longer be accessed.
self.sharding_annotated_module = None
self.spmd_partitioned_hlo_module = spmd_partitioned_hlo_module
return spmd_partitioned_hlo_module
def get_runnable(self, mesh=None):
"""Return a callable of the pipeline computation."""
if not mesh:
raise RuntimeError(
"`XlaShardedPipelineComputation` requires a mesh.")
hlo_module = self.get_spmd_partitioned()
avals = [var.aval for var in self.invars]
out_avals = [var.aval for var in self.outvars]
mesh_executable = PartialGradAccMeshDriverExecutable(
mesh, hlo_module, self.stage_plan, avals, out_avals,
self.donated_invars, self.output_acc_grad_indices)
return mesh_executable.get_driver_callable()
def get_hlo_text(self):
"""Get the HLO text."""
assert self.sharding_annotated_module is not None
return self.sharding_annotated_module.to_string()
def slice_closed_jaxpr_by_full_pipeline_marks(
closed_jaxpr: ClosedJaxpr) -> Sequence[JaxPipelineComputation]:
"""Slice a closed jaxpr into multiple JaxPipelineComputation by full
pipeline markers."""
global_consts_dir = dict(
zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))
result_computations = []
current_computation = None
for eqn in closed_jaxpr.jaxpr.eqns:
if eqn.primitive is pipeline_p and eqn.params["mark_type"] == "start":
assert current_computation is None, (
"Defining a pipeline computation "
"inside a pipeline computation is "
"not allowed.")
current_computation = JaxPipelineComputation(
name=eqn.params["name"])
for var in eqn.invars:
if isinstance(var, Literal):
pass
elif var in global_consts_dir:
current_computation.consts_dir[var] = global_consts_dir[var]
else:
current_computation.invars.append(var)
assert current_computation is not None
current_computation.eqns.append(eqn)
if eqn.primitive is pipeline_p and eqn.params["mark_type"] == "end":
assert current_computation is not None, (
"Ending a pipeline computation before its start.")
assert current_computation.name == eqn.params["name"], (
"Ending a pipeline computation different from its start.")
for var in eqn.outvars:
current_computation.outvars.append(var)
result_computations.append(current_computation)
current_computation = None
return result_computations
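# Editor's illustrative sketch (hypothetical, not part of the original module):
# a minimal driver for the slicing pass above, assuming the given ClosedJaxpr
# already contains matching "start"/"end" pipeline markers. The helper is a
# documentation aid only and is never called by this module.
def _example_slice_by_pipeline_marks(marked_jaxpr: ClosedJaxpr):
    """Hypothetical demo: split a marker-annotated jaxpr into stages."""
    computations = slice_closed_jaxpr_by_full_pipeline_marks(marked_jaxpr)
    for computation in computations:
        # Each stage carries its own invars/outvars and can be compiled alone.
        logger.debug("stage %s has %d eqns", computation.name,
                     len(computation.eqns))
    return computations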
def mark_missing_vars_in_backward_computation_pipeline_marks(
computations: Sequence[JaxPipelineComputation], global_invars,
global_outvars, gensym_func):
"""
    Fix missing variables generated by jax.grad and alpa.grad.
    Add the missing input variables to the pipeline markers of the stages
    generated by jax.grad or alpa.grad, and remove unused variables from the
    pipeline markers.
"""
    assert len(computations) % 2 == 0
num_forward_computations = len(computations) // 2
var_computation_id = {}
for var in global_invars:
if not isinstance(var, Literal):
var_computation_id[var] = -1
computation_marked_to_unmarked_invars = [{} for _ in computations]
computation_weight_invars = [{} for _ in computations]
computation_additional_invars = [OrderedSet() for _ in computations]
computation_additional_outvars = [OrderedSet() for _ in computations]
for computation_id, computation in enumerate(computations):
for eqn in computation.eqns:
if eqn.primitive == pipeline_p and eqn.params[
"mark_type"] == "start":
for invar, outvar in zip(eqn.invars, eqn.outvars):
computation_marked_to_unmarked_invars[computation_id][
outvar] = invar
for var in eqn.invars:
if (not isinstance(var, Literal) and
var not in computation.consts_dir and
var not in computation.invars):
source_computation_id = var_computation_id[var]
if source_computation_id != computation_id:
# Special case for the model weights. If a backward
# computation is using an invar of a forward
# computation, do not let the invar go into the stage.
# Instead, we can directly use the original invar.
if (computation_id >= num_forward_computations and
source_computation_id
== 2 * num_forward_computations -
computation_id - 1 and
var in computation_marked_to_unmarked_invars[
source_computation_id]):
computation_weight_invars[computation_id][var] = (
computation_marked_to_unmarked_invars[
source_computation_id][var])
continue
# Mark all the variables in the backward computation
# that are not currently defined in pipeline markers.
if (source_computation_id != -1 and var not in
computations[source_computation_id].outvars):
computation_additional_outvars[
source_computation_id].add(var)
computation_additional_invars[computation_id].add(var)
for var in eqn.outvars:
var_computation_id[var] = computation_id
for var in global_outvars:
source_computation_id = var_computation_id[var]
if source_computation_id != -1 and var not in computations[
source_computation_id].outvars:
computation_additional_outvars[source_computation_id].add(var)
new_computations = []
for i, computation in enumerate(computations):
assert (computation.eqns[0].primitive is pipeline_p and
computation.eqns[0].params["mark_type"] == "start")
assert (computation.eqns[-1].primitive is pipeline_p and
computation.eqns[-1].params["mark_type"] == "end")
new_computation = JaxPipelineComputation(
computation.name, consts_dir=computation.consts_dir)
computation_var_mapping = {
var: gensym_func(var.aval)
for var in computation_additional_invars[i] |
computation_additional_outvars[i] |
computation_weight_invars[i].keys()
}
pipeline_start_invars = list(computation.eqns[0].invars)
pipeline_start_outvars = [
get_var_mapping(computation_var_mapping, var)
for var in computation.eqns[0].outvars
]
new_computation.invars = list(computation.invars)
for var in computation_additional_invars[i]:
pipeline_start_invars.append(var)
pipeline_start_outvars.append(computation_var_mapping[var])
for marked_var, unmarked_var in computation_weight_invars[i].items():
pipeline_start_invars.append(unmarked_var)
pipeline_start_outvars.append(computation_var_mapping[marked_var])
pipeline_start_invars_without_literal = []
pipeline_start_outvars_without_literal = []
for invar, outvar in zip(pipeline_start_invars, pipeline_start_outvars):
if isinstance(invar, Literal):
computation_var_mapping[outvar] = invar
else:
pipeline_start_invars_without_literal.append(invar)
pipeline_start_outvars_without_literal.append(outvar)
new_computation.invars = list(pipeline_start_invars_without_literal)
new_computation.eqns.append(computation.eqns[0]._replace(
invars=pipeline_start_invars_without_literal,
outvars=pipeline_start_outvars_without_literal))
for eqn in computation.eqns[1:-1]:
invars = [
get_var_mapping(computation_var_mapping, var)
for var in eqn.invars
]
outvars = [
get_var_mapping(computation_var_mapping, var)
for var in eqn.outvars
]
new_computation.eqns.append(
eqn._replace(invars=invars, outvars=outvars))
pipeline_end_invars = [
get_var_mapping(computation_var_mapping, var)
for var in computation.eqns[-1].invars
]
pipeline_end_outvars = list(computation.eqns[-1].outvars)
for var in computation_additional_outvars[i]:
pipeline_end_invars.append(computation_var_mapping[var])
pipeline_end_outvars.append(var)
pipeline_end_invars_without_dropvar = []
pipeline_end_outvars_without_dropvar = []
for invar, outvar in zip(pipeline_end_invars, pipeline_end_outvars):
if not isinstance(outvar, DropVar):
pipeline_end_invars_without_dropvar.append(invar)
pipeline_end_outvars_without_dropvar.append(outvar)
new_computation.outvars = list(pipeline_end_outvars_without_dropvar)
new_computation.eqns.append(computation.eqns[-1]._replace(
invars=pipeline_end_invars_without_dropvar,
outvars=pipeline_end_outvars_without_dropvar))
new_computations.append(new_computation)
return new_computations
def pipeline_dce(jax_pipeline_computations: Sequence[JaxPipelineComputation],
global_outvars):
"""
    Clear unused variables across pipeline computations.
    This function removes the original gradients and only keeps the
    accumulated gradients.
"""
def dce_pipe_marker(marker: JaxprEqn, used_set):
kept_indices = [
i for i, var in enumerate(marker.outvars) if var in used_set
]
new_marker = mark_pipeline_jaxpreqn(
[marker.invars[i] for i in kept_indices],
[marker.outvars[i] for i in kept_indices], marker.params["name"],
marker.params["mark_type"])
return new_marker
global_used = OrderedSet(global_outvars)
new_computations = []
for computation in reversed(jax_pipeline_computations):
new_eqns = []
# handle pipe end
pipe_end = computation.eqns[-1]
assert (pipe_end.primitive is pipeline_p and
pipe_end.params["mark_type"]
== "end"), "computation not ended by a pipeline marker"
new_pipe_end = dce_pipe_marker(pipe_end, global_used)
new_eqns.append(new_pipe_end)
# handle normal instructions
local_used = OrderedSet(new_pipe_end.invars)
for eqn in reversed(computation.eqns[1:-1]):
for outvar in eqn.outvars:
if not isinstance(outvar, DropVar) and outvar in local_used:
new_eqns.append(eqn)
local_used.update([
invar for invar in eqn.invars if isinstance(invar, Var)
])
break
# handle pipe start
pipe_start = computation.eqns[0]
assert (pipe_start.primitive is pipeline_p and
pipe_start.params["mark_type"]
== "start"), "computation not started by a pipeline marker"
new_pipe_start = dce_pipe_marker(pipe_start, local_used)
new_eqns.append(new_pipe_start)
global_used.update(new_pipe_start.invars)
new_eqns = list(reversed(new_eqns))
new_computation = JaxPipelineComputation(
computation.name,
invars=new_pipe_start.invars,
outvars=new_pipe_end.outvars,
eqns=new_eqns,
consts_dir=computation.consts_dir)
new_computations.append(new_computation)
new_computations = list(reversed(new_computations))
return new_computations
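# Editor's illustrative sketch (hypothetical, not part of the original module):
# it combines the slicing pass with pipeline_dce, using the outvars of the
# merged jaxpr as the liveness roots. Documentation aid only; never called.
def _example_pipeline_dce(marked_jaxpr: ClosedJaxpr):
    """Hypothetical demo: slice a jaxpr into stages, then run cross-stage DCE."""
    computations = slice_closed_jaxpr_by_full_pipeline_marks(marked_jaxpr)
    # Variables that must stay alive after DCE: the global outputs.
    global_outvars = marked_jaxpr.jaxpr.outvars
    return pipeline_dce(computations, global_outvars)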
def _offload_remat_forward_remove_outvars(forward_stage, offloaded_eqns,
gensym_func):
removed_outvars = set()
removed_marker_mapping = {}
for eqn in offloaded_eqns:
not_dropped = [
var for var in eqn.outvars if not isinstance(var, DropVar)
]
removed_outvars.update(not_dropped)
previous_end = forward_stage.eqns[-1]
new_invars = []
new_outvars = []
for i, o in zip(previous_end.invars, previous_end.outvars):
if i in removed_outvars:
removed_marker_mapping[i] = o
continue
new_invars.append(i)
new_outvars.append(o)
add_dummy_dependency_var = (len(forward_stage.invars) != 0 or
len(new_outvars) != 0)
# TODO(zhuohan): Here we add a dummy byte from forward stage to
# backward stage to add a dependency link from the forward stage to
# the backward stage. Should not need this once we fixed the stage
# slicing in XLA.
new_eqns = list(forward_stage.eqns)
if add_dummy_dependency_var:
zero_literal = Literal(0, raise_to_shaped(get_aval(0)))
dummy_outvar = gensym_func(zero_literal.aval)
dummy_eqn = new_jaxpr_eqn([zero_literal, zero_literal], [dummy_outvar],
jax.lax.add_p, {})
new_eqns.insert(-1, dummy_eqn)
new_invars.append(dummy_outvar)
marked_dummy_outvar = gensym_func(dummy_outvar.aval)
new_outvars.append(marked_dummy_outvar)
else:
marked_dummy_outvar = None
new_eqns[-1] = mark_pipeline_jaxpreqn(new_invars, new_outvars,
previous_end.params["name"], "end")
new_forward = JaxPipelineComputation(forward_stage.name,
forward_stage.invars, new_outvars,
new_eqns, forward_stage.consts_dir)
return new_forward, removed_marker_mapping, marked_dummy_outvar
def _offload_remat_add_eqns(stage: JaxPipelineComputation, offloaded_eqns,
var_mapping, dummy_var, gensym_func):
removed_after_end_marker = set(var_mapping.values())
previous_start = stage.eqns[0]
new_invars = []
new_outvars = []
new_eqns = list(stage.eqns)
for i, o in zip(previous_start.invars, previous_start.outvars):
if i in removed_after_end_marker:
var_mapping[i] = o
continue
new_invars.append(i)
new_outvars.append(o)
if dummy_var:
new_invars.append(dummy_var)
new_outvars.append(gensym_func(dummy_var.aval))
new_eqns[0] = mark_pipeline_jaxpreqn(new_invars, new_outvars,
previous_start.params["name"], "start")
for eqn in offloaded_eqns:
mapped_outvars = [
var_mapping[var_mapping[var]] if
(var in var_mapping and var_mapping[var] in var_mapping) else var
for var in eqn.outvars
]
mapped_eqn = new_jaxpr_eqn(eqn.invars, mapped_outvars, eqn.primitive,
eqn.params, eqn.source_info)
new_eqns.insert(1, mapped_eqn)
new_stage = JaxPipelineComputation(stage.name, new_invars, stage.outvars,
new_eqns)
return new_stage
def offload_remat(jax_pipeline_computations: Sequence[JaxPipelineComputation],
gensym_func):
"""Offload remat call from forward to backward.
remat in Jax generates some remat_call in the forward part, but the output
of these remat_call is used only in the backward. Besides, these remat_call
only outputs constant. Hence, offloading them into the backward part does
not enlong any liveness interval, while helps reduce forward output size.
Args:
jax_pipeline_computations: pipeline stages including both forward and
backward, but no other.
gensym_func: gensym to create new Var different from existing ones.
Returns:
jax_pipeline_computations (Sequence[JaxPipelineComputation]):
computations after this transformation.
"""
def only_create_consts(jaxpr: Jaxpr):
const_vars = OrderedSet()
for eqn in jaxpr.eqns:
for var in eqn.invars:
if isinstance(var, Var) and var not in const_vars:
return False
const_vars.update(
[v for v in eqn.outvars if not isinstance(v, DropVar)])
return True
def only_input_consts(eqn: JaxprEqn):
in_bytes = 0
for var in eqn.invars:
if not isinstance(var, Var):
continue
if isinstance(var, DropVar):
continue
in_bytes += np.prod(var.aval.shape) * np.dtype(
var.aval.dtype).itemsize
return in_bytes == 0
num_layers = len(jax_pipeline_computations) // 2
new_computations = list(jax_pipeline_computations)
for i in range(num_layers):
forward_stage = new_computations[i]
offloaded_eqns = []
for eqn in reversed(forward_stage.eqns):
if eqn.primitive == pipeline_p:
continue
if (eqn.primitive == remat_call_p and
only_create_consts(eqn.params["call_jaxpr"]) and
only_input_consts(eqn)):
offloaded_eqns.append(eqn)
# remove outvars from forward stage
# assert len(offloaded_eqns)#, forward_stage.closed_jaxpr()
(new_forward, removed_var_mapping,
marked_dummy_outvar) = _offload_remat_forward_remove_outvars(
forward_stage, offloaded_eqns, gensym_func)
removed_var_post_marker = set(removed_var_mapping.values())
# remove invars and add eqn into backward stage
for stage_idx, stage in enumerate(new_computations):
if stage_idx == i:
continue
stage_invars = set(stage.invars)
if stage_invars.intersection(removed_var_post_marker):
dummy_outvar = (marked_dummy_outvar if
(stage_idx == num_layers * 2 - 1 - i) else None)
new_computations[stage_idx] = _offload_remat_add_eqns(
stage, offloaded_eqns, removed_var_mapping, dummy_outvar,
gensym_func)
new_computations[i] = new_forward
return new_computations
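# Editor's illustrative sketch (hypothetical, not part of the original module):
# driving offload_remat on a forward/backward stage list, reusing
# jax.core.gensym for fresh variables in the same way as the rest of this
# file. Documentation aid only; never called by this module.
def _example_offload_remat(computations: Sequence[JaxPipelineComputation]):
    """Hypothetical demo: move constant-only remat calls to backward stages."""
    gensym_func = gensym(
        [computation.closed_jaxpr().jaxpr for computation in computations])
    return offload_remat(computations, gensym_func)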
def rearrange_vars(vars,
selected: Sequence[Var],
pipe_marker=None,
is_input=True):
"""
    Rearrange vars so that those in `selected` come first.
    If pipe_marker is given, rearrange the invars and outvars of the pipe
    marker as well.
    Args:
        vars (Sequence[Var]): all vars to be rearranged.
        selected (Sequence[Var]): vars selected to be placed first.
        pipe_marker (JaxprEqn): the pipe marker corresponding to vars.
        is_input (bool): whether vars are the inputs of pipe_marker; if False,
            they are its outputs.
"""
new_vars = list(selected)
selected = OrderedSet(selected)
for var in vars:
if var not in selected:
new_vars.append(var)
if pipe_marker is None:
return new_vars
if is_input:
new_invars = new_vars
invar_idx = {v: idx for idx, v in enumerate(pipe_marker.invars)}
new_outvars = [
pipe_marker.outvars[invar_idx[var]] for var in new_invars
]
else:
new_outvars = new_vars
outvar_idx = {v: idx for idx, v in enumerate(pipe_marker.outvars)}
new_invars = [
pipe_marker.invars[outvar_idx[var]] for var in new_outvars
]
new_marker = mark_pipeline_jaxpreqn(new_invars, new_outvars,
pipe_marker.params["name"],
pipe_marker.params["mark_type"])
return new_vars, new_marker
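# Editor's illustrative sketch (hypothetical, not part of the original module):
# a small demonstration of rearrange_vars without a pipe marker: the selected
# vars move to the front while the relative order of the remaining vars is
# preserved. Documentation aid only; never called by this module.
def _example_rearrange_vars(all_vars: Sequence[Var], selected: Sequence[Var]):
    """Hypothetical demo: move `selected` vars to the front of `all_vars`."""
    reordered = rearrange_vars(all_vars, selected)
    assert reordered[:len(selected)] == list(selected)
    return reordered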
def generate_computations_from_modules(jax_computations, computation_names,
computation_modules, donate_invars,
donatable_lists, acc_grad_outvars,
stage_plan):
"""Generate pipeline computation from HLO modules."""
module_dict = dict(zip(computation_names, computation_modules))
computations = [
XlaShardedPipelineComputation.from_auto_sharded_computation(
sharding_annotated_module=module_dict[computation.name],
jax_pipeline_computation=computation,
stage_plan=stage_plan,
donated_invars=donate_invars,
acc_grad_outvars=acc_grad_outvars,
donatables=donatables)
for computation, donate_invars, donatables in zip(
jax_computations, donate_invars, donatable_lists)
]
return computations
def generate_sharded_xla_computations_arguments(
name: str, jax_computations: Sequence[JaxPipelineComputation],
computation_donate_invars: Sequence[bool],
output_sharding_dict: Dict[Var, pxla.ShardingSpec],
stage_input_sharding: Optional[Sequence[pxla.ShardingSpec]]):
"""
    Generate the arguments for distributed compilation.
    Similar to generate_sharded_xla_computations, but only generates the
    arguments.
"""
invars = OrderedSet()
outvars = OrderedSet()
donation_mapping = {}
eqns = []
consts_dir = {}
for computation, donation in zip(jax_computations,
computation_donate_invars):
consts_dir.update(computation.consts_dir)
# Do not add local invars into the invars
invars.update([var for var in computation.invars if var not in outvars])
outvars.update(computation.outvars)
for idx, var in enumerate(computation.invars):
if not donation[idx] or var not in invars:
continue
donation_mapping[computation.invars[idx]] = computation.outvars[idx]
eqns += computation.eqns
invars = rearrange_vars(invars, donation_mapping.keys())
outvars = rearrange_vars(outvars, donation_mapping.values())
jaxpr = Jaxpr(
constvars=list(consts_dir.keys()),
invars=list(invars),
outvars=list(outvars),
eqns=eqns,
)
donation_num = len(donation_mapping)
dummy_donated_invars = (True,) * donation_num + (False,) * (len(invars) -
donation_num)
closed_jaxpr = ClosedJaxpr(jaxpr, consts_dir.values())
backend_name = "gpu"
backend = xb.get_backend(backend_name)
hlo_module = jaxpr_to_hlo_module(name, closed_jaxpr, dummy_donated_invars,
backend)
if output_sharding_dict:
sharding_protos = [
output_sharding_dict[x].sharding_proto() for x in outvars
]
xe.set_hlo_module_output_shardings(hlo_module, sharding_protos)
if stage_input_sharding:
sharding_protos = [
sharding_spec.sharding_proto()
for sharding_spec in stage_input_sharding
]
xe.set_hlo_module_input_shardings(hlo_module, sharding_protos)
flops = xe.hlo_module_count_flop_dot_conv_only(hlo_module)
return hlo_module, flops
def generate_sharded_xla_computations(
name: str, jax_computations: Sequence[JaxPipelineComputation],
computation_donate_invars, donatable_lists, acc_grad_outvars,
num_micro_batches, logical_mesh, autosharding_option,
output_sharding_dict, stage_input_sharding):
"""
Generate sharded XLA computations.
It runs the auto-sharding pass on the given JaxPipelineComputations.
    Note: we merge the co-located forward and backward computations and
    compile them together to get a sharding strategy config.
"""
hlo_module, flops = generate_sharded_xla_computations_arguments(
name, jax_computations, computation_donate_invars, output_sharding_dict,
stage_input_sharding)
# pylint: disable=unbalanced-tuple-unpacking
(computation_names, computation_modules,
stage_plan) = run_auto_sharding_pass(hlo_module, logical_mesh, "stages",
num_micro_batches,
autosharding_option)
computations = generate_computations_from_modules(
jax_computations, computation_names, computation_modules,
computation_donate_invars, donatable_lists, acc_grad_outvars,
stage_plan)
return computations, flops
def rewrite_hook(eqns, gensym_fn):
"""TODO(zhuohan)."""
for idx, eqn in enumerate(eqns):
eqn: JaxprEqn
if ("mark_type" in eqn.params and eqn.params["mark_type"] == "hook"):
used_vars = OrderedSet()
defined_vars = OrderedSet()
for e in eqns[0:idx]:
defined_vars.update(
[v for v in e.outvars if not isinstance(v, DropVar)])
for e in eqns[idx + 1:]:
used_vars.update([v for v in e.invars if isinstance(v, Var)])
marked = used_vars.intersection(defined_vars)
hooked = list(marked)
new_hook = mark_hook_jaxpreqn(hooked,
[gensym_fn(v.aval) for v in hooked])
rewrite_dict = dict(zip(hooked, new_hook.outvars))
eqns[idx] = new_hook
for i in range(idx + 1, len(eqns)):
e = eqns[i]
eqns[i] = new_jaxpr_eqn(
[get_var_mapping(rewrite_dict, v) for v in e.invars],
e.outvars, e.primitive, e.params)
return new_hook
return None
def _wrap_with_call(closed_jaxpr: ClosedJaxpr, invars, outvars, name):
new_invars = closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars
jaxpr = clone_jaxpr(closed_jaxpr, new_invars, constvars=[]).jaxpr
params = dict(name=name, call_jaxpr=jaxpr)
return new_jaxpr_eqn(invars + closed_jaxpr.consts,
outvars,
named_call_p,
params=params)
def _rearrange_in_out_for_donation(invars, outvars, donation_map):
outvar_set = set(outvars)
donated_invars = [
var for var in invars
if (var in donation_map and donation_map[var] in outvar_set)
]
donated_outvars = [donation_map[var] for var in donated_invars]
invars = rearrange_vars(invars, donated_invars)
outvars = rearrange_vars(outvars, donated_outvars)
num_donated = len(donated_invars)
return invars, outvars, num_donated
def merge_unmarked_with_call(jaxprs: Sequence[ClosedJaxpr],
names: Sequence[str],
outvars,
donation_map=None):
"""Merge a sequence of jaxprs (no pipeline marker) using named call."""
gensym_fn = gensym([closed_jaxpr.jaxpr for closed_jaxpr in jaxprs])
eqns = []
invars = OrderedSet()
intermediates = OrderedSet()
const_dir = {}
for stage_name, closed_jaxpr in zip(names, jaxprs):
invars.update(closed_jaxpr.jaxpr.invars)
intermediates.update(closed_jaxpr.jaxpr.outvars)
const_dir.update(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))
jaxpr = closed_jaxpr.jaxpr
sym_invars = [gensym_fn(var.aval) for var in jaxpr.invars]
sym_outvars = [gensym_fn(var.aval) for var in jaxpr.outvars]
eqns.append(
mark_pipeline_jaxpreqn(jaxpr.invars, sym_invars, stage_name,
"start"))
eqns.append(
_wrap_with_call(closed_jaxpr, sym_invars, sym_outvars, stage_name))
eqns.append(
mark_pipeline_jaxpreqn(sym_outvars, jaxpr.outvars, stage_name,
"end"))
invars.difference_update(intermediates)
# handle donation
num_donated = 0
if donation_map:
(invars, outvars,
num_donated) = _rearrange_in_out_for_donation(invars, outvars,
donation_map)
is_donated = [True] * num_donated + [False] * (len(invars) - num_donated)