-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathauto_sharding_util.cc
2276 lines (2015 loc) · 79 KB
/
auto_sharding_util.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "tensorflow/compiler/xla/service/spmd/auto_sharding_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
namespace spmd {
// Forward declarations so the helpers below can use these pass-through
// utilities before their definitions. PassThroughCustomCallMarkerGetSource is
// defined further down in this file; the definition of
// PassThroughCustomCallMarkerUser is not in this view (NOTE(review): presumably
// later in this file as well, since it is declared `inline` here — confirm).
inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
const HloInstruction* ins);
inline HloInstruction* PassThroughCustomCallMarkerUser(
HloInstruction* raw_user, const HloInstruction* inst);
// Returns the single shared NullStream instance. It is a function-local
// static, so it is created on first use and reused by every later call.
NullStream& NullStream::Global() {
  static NullStream shared_instance;
  return shared_instance;
}
// Custom-call target used for pipeline-stage marker instructions. Code in this
// file matches it via IsCustomCall(kPipelineMarker) and inspects the marker's
// metadata op_type for "start".
const char* const kPipelineMarker = "pipeline_marker";
// Custom-call target name for the cross-mesh all-reduce builtin. Not referenced
// in the visible part of this file (NOTE(review): used elsewhere — confirm).
const char* const kCrossMeshAllReduce = "__builtin$CrossMeshAllReduce";
// Return whether a reshape instruction is a special reshape that switches
// the batch dim of a dot.
bool IsBatchDimSwitchReshape(const HloInstruction* inst) {
if (inst->opcode() != HloOpcode::kReshape) {
return false;
}
if (inst->users().size() != 1) {
return false;
}
const HloInstruction* operand = inst->operand(0);
const HloInstruction* user = inst->users().front();
if (operand->opcode() != HloOpcode::kDot) {
return false;
}
int batch_dims = operand->dot_dimension_numbers().lhs_batch_dimensions_size();
if (batch_dims <= 0) {
return false;
}
if (user->opcode() != HloOpcode::kTranspose) {
return false;
}
return true;
}
// Returns true if following the chain of first users of `ins` (looking through
// pass-through custom-call markers) reaches a broadcast.
// Reshapes along the chain are free. At most 6 other (non-reshape,
// non-broadcast) hops are taken before giving up. The search stops with false
// as soon as an instruction on the chain has no users.
bool IsFollowedByBroadcast(const HloInstruction* ins) {
  constexpr int kMaxCountedHops = 6;
  int counted_hops = 0;
  while (counted_hops < kMaxCountedHops) {
    if (ins->users().empty()) {
      return false;
    }
    ins = PassThroughCustomCallMarkerUser(ins->users().front(), ins);
    switch (ins->opcode()) {
      case HloOpcode::kBroadcast:
        return true;
      case HloOpcode::kReshape:
        // Reshapes do not count against the hop budget.
        break;
      default:
        ++counted_hops;
        break;
    }
  }
  return false;
}
// Return whether the instruction is an activation from another pipeline stage.
bool IsActivationFromAnotherStage(const HloInstruction* ins,
const InstructionBatchDimMap& batch_dim_map) {
if (!(ins->opcode() == HloOpcode::kParameter && batch_dim_map.count(ins))) {
return false;
}
for (const HloInstruction* user : ins->users()) {
if (!(user->opcode() == HloOpcode::kTuple && user->users().size() == 1 &&
user->users().front()->IsCustomCall(kPipelineMarker) &&
user->users().front()->metadata().op_type().find("start") !=
std::string::npos)) {
return false;
}
}
if (primitive_util::IsIntegralType(ins->shape().element_type())) {
// TODO(lmzheng): This is a temporary hack. We use this to filter out
// the input word ids and position ids. These are global input so they are
// not activations from the previous stage. If we do not filter out them,
// some follow-up instructions will follow the wrong instructions.
return false;
}
return true;
}
// Propagate sharding for broadcast.
// Each output dimension that comes from an operand dimension (per
// `dimensions`) takes that operand dimension's tile count. Newly broadcast
// dimensions get a tile count of 1 (not tiled). A trailing replication
// dimension of a partially tiled input is carried over. Replicated input is
// returned unchanged. `new_shape` must be an array shape.
HloSharding BroadcastSharding(const HloSharding& input_spec,
                              const Shape& new_shape,
                              const absl::Span<const int64_t>& dimensions) {
  if (input_spec.IsReplicated()) {
    return input_spec;
  }
  CHECK(new_shape.IsArray());

  const bool partially_replicated = input_spec.ReplicateOnLastTileDim();
  const Array<int64_t>& source_tiles = input_spec.tile_assignment();

  std::vector<int64_t> tile_dims;
  for (int64_t out_dim = 0; out_dim < new_shape.rank(); ++out_dim) {
    const auto pos = absl::c_find(dimensions, out_dim);
    tile_dims.push_back(pos == dimensions.end()
                            ? 1
                            : source_tiles.dim(pos - dimensions.begin()));
  }
  if (partially_replicated) {
    tile_dims.push_back(source_tiles.dimensions().back());
  }

  Array<int64_t> reshaped_tiles = source_tiles;
  reshaped_tiles.Reshape(tile_dims);
  if (partially_replicated) {
    return HloSharding::PartialTile(reshaped_tiles);
  }
  return HloSharding::Tile(reshaped_tiles);
}
// Propagate sharding for dim-wise operations (e.g., slice, pad) that work
// independently on each dimension.
// The input sharding is reused as-is when no tiled dimension (tile count > 1)
// changes extent between `old_shape` and `new_shape`. Otherwise there is no
// communication-free propagation and std::nullopt is returned.
std::optional<HloSharding> PropagateDimwiseSharding(
    const HloSharding& input_spec, const Shape& old_shape,
    const Shape& new_shape) {
  if (!input_spec.IsReplicated()) {
    CHECK(old_shape.IsArray());
    const auto& tiles = input_spec.tile_assignment();
    for (int64_t d = 0; d < old_shape.rank(); ++d) {
      // Short-circuit keeps new_shape untouched for untiled dimensions.
      if (tiles.dim(d) > 1 &&
          new_shape.dimensions(d) != old_shape.dimensions(d)) {
        return std::nullopt;
      }
    }
  }
  return input_spec;
}
// Propagate sharding for ReduceWindow-like operations.
// The sharding can successfully propagate if the window operation only happens
// on tensor dimensions that are not tiled.
std::optional<HloSharding> PropagateReduceWindowSharding(
const HloSharding& input_spec, const Shape& old_shape,
const Window& window) {
if (input_spec.IsReplicated()) {
return input_spec;
}
CHECK(!input_spec.IsTuple());
const auto& tile_assignment = input_spec.tile_assignment();
for (int64_t i = 0; i < old_shape.rank(); ++i) {
if (tile_assignment.dim(i) > 1 && window.dimensions(i).size() != 1) {
return std::nullopt;
}
}
return input_spec;
}
// Pass through the custom call marker and get the source instruction.
// While `ins` is a get-tuple-element reading from a pass-through tuple, it is
// replaced as follows:
//   * walk from the marker's operand(0) through any further pass-through
//     tuples (always via operand(0)), then
//   * select the operand at the same tuple index.
// Any other instruction is returned unchanged.
inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
    const HloInstruction* ins) {
  for (;;) {
    if (ins->opcode() != HloOpcode::kGetTupleElement) {
      return ins;
    }
    const HloInstruction* marker = ins->operand(0);
    if (!IsPassThroughTuple(marker)) {
      return ins;
    }
    const HloInstruction* producer = marker->operand(0);
    while (IsPassThroughTuple(producer)) {
      producer = producer->operand(0);
    }
    ins = producer->operand(ins->tuple_index());
  }
}
// Depth analysis (breadth first search).
// We also assign a much larger distance to heavy operators (e.g., dot,
// convolution).
//
// Visits `sequence` in topological order, one level at a time (Kahn's
// algorithm): an instruction becomes ready once all of its unique operands
// have been visited. Its depth is:
//   - 0 for instructions with no operands. Parameters accepted by
//     IsActivationFromAnotherStage() start at 20 instead.
//   - 0 for reduce instructions (the depth is reset).
//   - the depth of the forwarded source for a get-tuple-element of a
//     pass-through tuple.
//   - otherwise, max over operands of (operand depth + delta). delta is 1000
//     for dot/convolution, -5 for broadcast, 0 for reshape and 1 for anything
//     else.
// CHECK-fails if the traversal runs out of ready instructions before every
// instruction in `sequence` has a depth.
InstructionDepthMap BuildInstructionDepthMap(
const HloInstructionSequence& sequence,
const InstructionBatchDimMap& batch_dim_map) {
const std::vector<HloInstruction*>& instructions = sequence.instructions();
InstructionDepthMap depth_map;
// Number of unique operands of each instruction not yet visited.
absl::flat_hash_map<const HloInstruction*, size_t> degree_dict;
// Init frontier
size_t collected = 0;
std::vector<const HloInstruction*> current_frontier;
for (const HloInstruction* inst : instructions) {
degree_dict[inst] = inst->unique_operands().size();
if (degree_dict[inst] == 0) {
depth_map[inst] = 0;
// Add some initial depth for activations from other pipeline stages.
if (IsActivationFromAnotherStage(inst, batch_dim_map)) {
depth_map[inst] = 20;
}
current_frontier.push_back(inst);
collected++;
}
}
// Push forward
std::vector<const HloInstruction*> next_frontier;
while (collected < instructions.size()) {
// An empty frontier with instructions left means some instruction can never
// become ready.
CHECK(!current_frontier.empty());
next_frontier.clear();
for (const HloInstruction* inst : current_frontier) {
for (const HloInstruction* node : inst->users()) {
// One more operand of `node` has been visited; it is ready at zero.
// NOTE(review): a user that is not in `sequence` is default-inserted with 0
// here and its size_t count wraps, so it never becomes ready — confirm all
// users belong to `sequence`.
int now_degree = --degree_dict[node];
if (now_degree == 0) {
int64_t delta = 0;
bool reset = false;
// Heavy operators have more weight (distance).
switch (node->opcode()) {
case HloOpcode::kDot:
case HloOpcode::kConvolution:
delta = 1000;
break;
// A temporary hack here: reduce ops will generate replicated
// sharding. We do not want the later broadcast and elementwise ops
// to follow it. So we give reduce ops some penalty and let the
// elementwise ops to follow other operands.
// TODO(lmzheng): remove this hack by correctly registering
// strategies for broadcast.
case HloOpcode::kReduce:
reset = true;
break;
// For similar reasons mentioned above, we give some penalty to
// broadcast.
case HloOpcode::kBroadcast:
delta = -5;
break;
case HloOpcode::kReshape:
delta = 0;
break;
default:
delta = 1;
break;
}
if (reset) {
depth_map[node] = 0;
} else if (node->opcode() == HloOpcode::kGetTupleElement &&
IsPassThroughTuple(node->operand(0))) {
// Forwarding through a pass-through marker inherits the source's depth.
depth_map[node] =
depth_map.at(PassThroughCustomCallMarkerGetSource(node));
} else {
// All operands are visited by now, so depth_map.at() cannot miss.
int64_t max_depth = depth_map.at(inst) + delta;
for (const HloInstruction* operand : node->operands()) {
max_depth = std::max(max_depth, depth_map.at(operand) + delta);
}
depth_map[node] = max_depth;
}
next_frontier.push_back(node);
collected += 1;
}
}
}
std::swap(current_frontier, next_frontier);
}
return depth_map;
}
// Batch dimension analysis that finds the batch dimension of each instruction.
InstructionBatchDimMap BuildInstructionBatchDimMap(
const HloInstructionSequence& sequence) {
InstructionBatchDimMap batch_map;
const std::vector<HloInstruction*>& instructions = sequence.instructions();
// We use the first dot or convolution as the source to start batch dim
// propagation. Assume the first dim of the first dot is the batch dim.
int batch_dim_of_source = 0;
// Find the source of batch_dim propagation
bool set_the_next_dot_conv = true;
for (const HloInstruction* ins : instructions) {
if (ins->opcode() == HloOpcode::kDot ||
ins->opcode() == HloOpcode::kConvolution) {
if (set_the_next_dot_conv) {
set_the_next_dot_conv = false;
batch_map[ins] = batch_dim_of_source;
}
}
if (ins->IsCustomCall(kPipelineMarker) &&
ins->metadata().op_type().find("start") != std::string::npos) {
// Reset the status after meet a new pipeline marker.
set_the_next_dot_conv = true;
}
}
// Forward propagation: propagate from operand
for (const HloInstruction* ins : instructions) {
switch (ins->opcode()) {
case HloOpcode::kParameter:
case HloOpcode::kConstant:
case HloOpcode::kIota:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kRng:
break;
case HloOpcode::kBroadcast: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(operand)) {
int value = batch_map[operand];
int old_dim = -1;
for (int i = 0; i < ins->shape().rank(); ++i) {
if (absl::c_linear_search(dimensions, i)) {
old_dim++;
}
if (old_dim == value) {
batch_map[ins] = i;
break;
}
}
}
break;
}
case HloOpcode::kReshape: {
const HloInstruction* operand = ins->operand(0);
if (batch_map.count(operand)) {
int value = batch_map[operand];
int64_t batch_size = operand->shape().dimensions(value);
int pt_operand = 0;
int pt_ins = 0;
auto skip_one_dims = [&]() {
if (batch_size != 1) {
while (pt_operand + 1 < operand->shape().rank() &&
operand->shape().dimensions(pt_operand) == 1) {
pt_operand += 1;
}
while (pt_ins + 1 < ins->shape().rank() &&
ins->shape().dimensions(pt_ins) == 1) {
pt_ins += 1;
}
}
};
skip_one_dims();
bool match = true;
while (pt_operand < value) {
if (operand->shape().dimensions(pt_operand) !=
ins->shape().dimensions(pt_ins)) {
match = false;
break;
}
pt_operand += 1;
pt_ins += 1;
skip_one_dims();
}
if (match) {
batch_map[ins] = pt_ins;
}
}
break;
}
case HloOpcode::kTranspose: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(operand)) {
int value = batch_map[operand];
auto it = absl::c_find(dimensions, value);
batch_map[ins] = it - dimensions.begin();
}
break;
}
case HloOpcode::kReverse:
case HloOpcode::kPad:
case HloOpcode::kSlice:
case HloOpcode::kConcatenate:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
// Binary elementwise operations
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
// Ternary elementwise operations.
case HloOpcode::kSelect:
case HloOpcode::kClamp: {
for (const HloInstruction* operand : ins->unique_operands()) {
if (batch_map.count(operand)) {
int value = batch_map[operand];
if (operand->shape().rank() == ins->shape().rank() &&
operand->shape().dimensions(value) ==
ins->shape().dimensions(value)) {
batch_map[ins] = batch_map[operand];
break;
}
}
}
break;
}
case HloOpcode::kReduce: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(operand)) {
int value = batch_map[operand];
if (value == 0 && !absl::c_linear_search(dimensions, value)) {
batch_map[ins] = value;
}
}
break;
}
case HloOpcode::kDot: {
const HloInstruction* lhs = ins->operand(0);
const HloInstruction* rhs = ins->operand(1);
const auto& dot_dnums = ins->dot_dimension_numbers();
int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size();
const auto& lhs_batch_dims =
ins->dot_dimension_numbers().lhs_batch_dimensions();
const auto& rhs_batch_dims =
ins->dot_dimension_numbers().rhs_batch_dimensions();
std::vector<int64_t> lhs_space_dims, rhs_space_dims;
std::tie(lhs_space_dims, rhs_space_dims) =
GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);
if (batch_map.count(lhs)) {
int value = batch_map[lhs];
for (int i = 0; i < lhs_batch_dims.size(); ++i) {
if (value == lhs_batch_dims[i]) {
batch_map[ins] = i;
break;
}
}
if (!lhs_space_dims.empty() && value == lhs_space_dims[0]) {
batch_map[ins] = space_base_dim;
}
}
if (batch_map.count(rhs)) {
int value = batch_map[rhs];
for (int i = 0; i < rhs_batch_dims.size(); ++i) {
if (value == rhs_batch_dims[i]) {
batch_map[ins] = i;
break;
}
}
if (!rhs_space_dims.empty() && value == rhs_space_dims[0]) {
batch_map[ins] = space_base_dim + 1;
}
}
break;
}
case HloOpcode::kConvolution: {
const HloInstruction* lhs = ins->operand(0);
const HloInstruction* rhs = ins->operand(1);
const auto& conv_dnums = ins->convolution_dimension_numbers();
if (batch_map.count(lhs)) {
int value = batch_map[lhs];
if (value == conv_dnums.input_batch_dimension()) {
batch_map[ins] = conv_dnums.output_batch_dimension();
}
}
if (batch_map.count(rhs)) {
int value = batch_map[rhs];
if (value == conv_dnums.kernel_output_feature_dimension()) {
batch_map[ins] = conv_dnums.output_feature_dimension();
}
}
break;
}
case HloOpcode::kGather:
case HloOpcode::kScatter: {
// We only handle one case for now:
// If gather/scatter does not happen on the batch dimension,
// then we can propagate the batch dim.
const HloInstruction* operand = ins->operand(0);
if (batch_map.count(operand)) {
int value = batch_map[operand];
if (ins->shape().rank() == operand->shape().rank() &&
ins->shape().dimensions(value) ==
operand->shape().dimensions(value)) {
batch_map[ins] = value;
}
}
break;
}
case HloOpcode::kSort: {
for (size_t i = 0; i < ins->operand_count(); ++i) {
const HloInstruction* operand = ins->operand(i);
if (batch_map.count(operand)) {
int value = batch_map[operand];
if (!absl::c_linear_search(ins->dimensions(), value)) {
batch_map[ins] = value;
break;
}
}
}
break;
}
case HloOpcode::kGetTupleElement: {
const HloInstruction* source =
PassThroughCustomCallMarkerGetSource(ins);
if (batch_map.count(source)) {
batch_map[ins] = batch_map[source];
}
break;
}
case HloOpcode::kTuple:
case HloOpcode::kCustomCall:
case HloOpcode::kOptimizationBarrier:
break;
case HloOpcode::kWhile:
break;
default:
LOG(FATAL) << "Unhandled instruction: " + ins->ToString();
}
}
// Backward propagation: propagate to operands
for (int64_t i = instructions.size() - 1; i >= 0; i--) {
const HloInstruction* ins = instructions[i];
switch (ins->opcode()) {
case HloOpcode::kBroadcast: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(ins) && !batch_map.count(operand)) {
int value = batch_map[ins];
int old_dim = -1;
for (int i = 0; i < ins->shape().rank(); ++i) {
if (absl::c_linear_search(dimensions, i)) {
old_dim++;
if (i == value) {
batch_map[operand] = old_dim;
break;
}
}
}
}
break;
}
case HloOpcode::kReshape: {
const HloInstruction* operand = ins->operand(0);
if (batch_map.count(ins) && !batch_map.count(operand)) {
int value = batch_map[ins];
int64_t batch_size = ins->shape().dimensions(value);
int pt_operand = 0;
int pt_ins = 0;
auto skip_one_dims = [&]() {
if (batch_size != 1) {
while (pt_operand + 1 < operand->shape().rank() &&
operand->shape().dimensions(pt_operand) == 1) {
pt_operand += 1;
}
while (pt_ins + 1 < ins->shape().rank() &&
ins->shape().dimensions(pt_ins) == 1) {
pt_ins += 1;
}
}
};
skip_one_dims();
bool match = true;
while (pt_ins < value) {
if (operand->shape().dimensions(pt_operand) !=
ins->shape().dimensions(pt_ins)) {
match = false;
break;
}
pt_operand += 1;
pt_ins += 1;
skip_one_dims();
}
if (match) {
batch_map[operand] = pt_operand;
}
}
break;
}
case HloOpcode::kTranspose: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(ins) && !batch_map.count(operand)) {
batch_map[operand] = dimensions[batch_map[ins]];
}
break;
}
case HloOpcode::kReverse:
case HloOpcode::kPad:
case HloOpcode::kSlice:
case HloOpcode::kConcatenate:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
// Binary elementwise operations
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
// Ternary elementwise operations.
case HloOpcode::kSelect:
case HloOpcode::kClamp: {
if (batch_map.count(ins)) {
int value = batch_map[ins];
for (const HloInstruction* operand : ins->unique_operands()) {
if (!batch_map.count(operand) &&
operand->shape().rank() == ins->shape().rank() &&
operand->shape().dimensions(value) ==
ins->shape().dimensions(value)) {
batch_map[operand] = value;
}
}
}
break;
}
case HloOpcode::kReduce: {
const HloInstruction* operand = ins->operand(0);
const auto& dimensions = ins->dimensions();
if (batch_map.count(ins) && !batch_map.count(operand)) {
int value = batch_map[ins];
if (value == 0 && !absl::c_linear_search(dimensions, value)) {
batch_map[operand] = value;
}
}
break;
}
case HloOpcode::kDot: {
const HloInstruction* lhs = ins->operand(0);
const HloInstruction* rhs = ins->operand(1);
const auto& dot_dnums = ins->dot_dimension_numbers();
int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size();
const auto& lhs_batch_dims =
ins->dot_dimension_numbers().lhs_batch_dimensions();
const auto& rhs_batch_dims =
ins->dot_dimension_numbers().rhs_batch_dimensions();
std::vector<int64_t> lhs_space_dims, rhs_space_dims;
std::tie(lhs_space_dims, rhs_space_dims) =
GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);
if (batch_map.count(ins)) {
int value = batch_map[ins];
if (!batch_map.count(lhs)) {
for (int i = 0; i < lhs_batch_dims.size(); ++i) {
if (value == i) {
batch_map[lhs] = lhs_batch_dims[i];
break;
}
}
if (!lhs_space_dims.empty() && value == space_base_dim) {
batch_map[lhs] = lhs_space_dims[0];
}
}
if (!batch_map.count(rhs)) {
for (int i = 0; i < rhs_batch_dims.size(); ++i) {
if (value == i) {
batch_map[rhs] = rhs_batch_dims[i];
break;
}
}
if (!rhs_space_dims.empty() && value == space_base_dim + 1) {
batch_map[rhs] = rhs_space_dims[0];
}
}
}
break;
}
case HloOpcode::kConvolution: {
const HloInstruction* lhs = ins->operand(0);
const HloInstruction* rhs = ins->operand(1);
const auto& conv_dnums = ins->convolution_dimension_numbers();
if (batch_map.count(ins)) {
int value = batch_map[ins];
if (value == conv_dnums.output_batch_dimension() &&
!batch_map.count(lhs)) {
batch_map[lhs] = conv_dnums.input_batch_dimension();
}
if (value == conv_dnums.output_feature_dimension() &&
!batch_map.count(rhs)) {
batch_map[rhs] = conv_dnums.kernel_output_feature_dimension();
}
}
break;
}
case HloOpcode::kGather:
case HloOpcode::kScatter: {
// We only handle one case for now:
// If gather/scatter does not happen on the batch dimension,
// then we can propagate the batch dim.
if (batch_map.count(ins)) {
int value = batch_map[ins];
const HloInstruction* operand = ins->operand(0);
if (ins->shape().rank() == operand->shape().rank() &&
ins->shape().dimensions(value) ==
operand->shape().dimensions(value)) {
batch_map[operand] = value;
}
}
break;
}
case HloOpcode::kSort: {
if (batch_map.count(ins)) {
int value = batch_map[ins];
if (!absl::c_linear_search(ins->dimensions(), value)) {
for (size_t i = 0; i < ins->operand_count(); ++i) {
const HloInstruction* operand = ins->operand(i);
batch_map[operand] = value;
}
}
}
break;
}
case HloOpcode::kGetTupleElement: {
const HloInstruction* source =
PassThroughCustomCallMarkerGetSource(ins);
if (batch_map.count(ins) && !batch_map.count(source)) {
batch_map[source] = batch_map[ins];
}
break;
}
case HloOpcode::kTuple:
case HloOpcode::kCustomCall:
case HloOpcode::kOptimizationBarrier:
break;
default:
break;
}
}
// Print batch map for debugging
// std::cerr << "Batch dim map begin" << std::endl;
// for (const HloInstruction* ins : instructions) {
// std::cerr << ins->ToString();
// if (batch_map.count(ins)) {
// std::cerr << " BATCH " << batch_map[ins] << std::endl;
// } else {
// std::cerr << " NOBATCH " << std::endl;
// }
// }
// std::cerr << "Batch dim map end" << std::endl;
return batch_map;
}
// Remove duplicated strategies with the same output sharding spec.
void RemoveDuplicatedStrategy(std::unique_ptr<StrategyVector>& strategies) {
std::vector<ShardingStrategy> new_vector;
absl::flat_hash_set<HloSharding> added;
CHECK(!strategies->is_tuple);
for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) {
if (!added.count(strategies->leaf_vector[i].output_sharding)) {
added.insert(strategies->leaf_vector[i].output_sharding);
new_vector.push_back(std::move(strategies->leaf_vector[i]));
}
}
strategies->leaf_vector = std::move(new_vector);
}
// Remove strategies whose output tensor's shape is not divisible by the tile
// factors defined in the sharding spec.
void RemoveIndivisibleStrategies(std::unique_ptr<StrategyVector>& strategies,
const Shape& shape) {
std::vector<ShardingStrategy> new_vector;
CHECK(!strategies->is_tuple);
for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) {
bool divisible = true;
const HloSharding& output_spec = strategies->leaf_vector[i].output_sharding;
if (!output_spec.IsReplicated()) {
CHECK(output_spec.IsTiled());
const Array<int64_t>& tile_assignment = output_spec.tile_assignment();
for (size_t j = 0; j < shape.rank(); ++j) {
if (shape.dimensions(j) % tile_assignment.dim(j) != 0) {
divisible = false;
break;
}
}
}
if (divisible) {
new_vector.push_back(std::move(strategies->leaf_vector[i]));
}
}
strategies->leaf_vector = std::move(new_vector);
}
// Filter strategies according to the solver_option.force_batch_dim_to_mesh_dim.
// This can be used to forcibly generate data-parallel strategies.
//
// Let m = solver_option.force_batch_dim_to_mesh_dim and b = batch_map.at(ins).
//   - If mesh dim m has more than one device, only strategies whose output
//     tiles tensor dim b along mesh dim m are kept.
//   - Otherwise only strategies that leave tensor dim b unsharded are kept.
// Returns InvalidArgument, leaving `strategies` unchanged, when:
//   - dim b is not divisible by the device count of mesh dim m, or
//   - no strategy survives the filter (this used to be a CHECK crash even
//     though the function already reports errors through Status).
Status FilterStrategy(const HloInstruction* ins,
                      std::unique_ptr<StrategyVector>& strategies,
                      const ClusterEnvironment& cluster_env,
                      const InstructionBatchDimMap& batch_map,
                      const AutoShardingSolverOption& solver_option) {
  int mesh_dim = solver_option.force_batch_dim_to_mesh_dim;
  int batch_dim = batch_map.at(ins);
  const Array<int64_t>& device_mesh = cluster_env.device_mesh;
  if (ins->shape().dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) {
    return tensorflow::errors::InvalidArgument(
        "The length of batch dimension is "
        "not divisible by the number of devices. " +
        ins->ToString());
  }
  // -1 in GetTensorDimToMeshDim's result marks an unsharded tensor dim.
  const int required_mesh_dim = device_mesh.dim(mesh_dim) > 1 ? mesh_dim : -1;

  std::vector<ShardingStrategy> new_leaf_vector;
  for (auto& stra : strategies->leaf_vector) {
    const std::vector<int> tensor_dim_to_mesh_dim =
        cluster_env.GetTensorDimToMeshDim(ins->shape(), stra.output_sharding);
    if (tensor_dim_to_mesh_dim[batch_dim] == required_mesh_dim) {
      new_leaf_vector.push_back(std::move(stra));
    }
  }
  if (new_leaf_vector.empty()) {
    // Nothing was moved out of `strategies`, so the caller still sees the
    // original candidates alongside the error.
    return tensorflow::errors::InvalidArgument(
        ins->ToString() + " does not have any valid strategies");
  }
  strategies->leaf_vector = std::move(new_leaf_vector);
  return Status::OK();
}
// Derives an ordered pair of mesh dimensions from a strategy name.
// Names containing "{0,1}" give (0, 1); every other name gives (1, 0).
inline std::pair<int, int> ParseMeshDims(const std::string& strategy_name) {
  const bool lists_dim0_first =
      strategy_name.find("{0,1}") != std::string::npos;
  return lists_dim0_first ? std::make_pair(0, 1) : std::make_pair(1, 0);
}
// Return whether the tensor shape is divisible by
// the number of devices along multiple dimensions.
bool IsDivisible(const HloInstruction* ins, const Array<int64_t>& device_mesh,
const std::vector<int64_t>& tensor_dims,
const std::vector<int64_t>& mesh_dims) {
CHECK_EQ(tensor_dims.size(), mesh_dims.size());
for (int64_t i = 0; i < tensor_dims.size(); ++i) {
if (ins->shape().dimensions(tensor_dims[i]) %
device_mesh.dim(mesh_dims[i]) !=
0) {
return false;
}
}
return true;
}
// Return the output sharding of the reduce-scatter variant of a given strategy.
HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const ShardingStrategy& strategy,
const ClusterEnvironment& cluster_env) {
const Array<int64_t>& device_mesh = cluster_env.device_mesh;
const Array<int64_t>& device_mesh_1d = cluster_env.device_mesh_1d;
if (ins->opcode() == HloOpcode::kDot) {
const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers();
int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size();
if (StrStartsWith(strategy.name, "SR = SS x SR") ||
StrStartsWith(strategy.name, "RS = RS x SS")) {