-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathauto_sharding.cc
2289 lines (2042 loc) · 89.7 KB
/
auto_sharding.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "tensorflow/compiler/xla/service/spmd/auto_sharding.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h"
namespace xla {
namespace spmd {
// Create a HloSharding that tiles some tensor dims on some device mesh dims.
//
// tensor_dims[i] of `shape` is split across mesh_dims[i] of `device_mesh`;
// the two vectors must have the same length (CHECK_EQ) and `shape` must be an
// array shape (CHECK). Tensor dims not listed in `tensor_dims` get a tile
// count of 1. If the product of the split mesh dimensions is smaller than the
// total number of devices, one extra trailing tile dimension holding the
// remaining devices is appended and a partially replicated sharding
// (HloSharding::PartialTile) is returned; otherwise a fully tiled sharding
// (HloSharding::Tile) is returned.
HloSharding Tile(const Shape& shape, const std::vector<int64_t> tensor_dims,
const std::vector<int64_t> mesh_dims,
const Array<int64_t>& device_mesh) {
CHECK_EQ(tensor_dims.size(), mesh_dims.size());
CHECK(shape.IsArray());
// One tile count per tensor dim; dims that are not split stay at 1.
std::vector<int64_t> tile_assignment_dimensions(shape.rank(), 1);
// Split on certain mesh dimensions
int64_t split_prod = 1;
for (size_t i = 0; i < tensor_dims.size(); ++i) {
tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]);
split_prod *= device_mesh.dim(mesh_dims[i]);
}
// Replicate on the remaining mesh dimensions: devices not consumed by the
// split above form a single trailing (replication) tile dimension.
bool replicate_on_last_tile_dim = false;
if (split_prod < device_mesh.num_elements()) {
tile_assignment_dimensions.push_back(device_mesh.num_elements() /
split_prod);
replicate_on_last_tile_dim = true;
}
// Map device ids from device_mesh to tile_assignment_devices.
// The lambda walks tensor dims in order, starting "before" dim 0 (-1). For a
// tensor dim mapped to a mesh dim it iterates over that mesh dim's index and
// records it in mesh_indices; at the last tensor dim it appends the selected
// device ids via AppendFlattenElements. mesh_indices is passed by value so
// each recursive branch keeps its own copy.
// NOTE(review): mesh dims left at -1 presumably mean "take all indices along
// this mesh dim" inside AppendFlattenElements — confirm against its definition.
std::vector<int64_t> tile_assignment_devices;
tile_assignment_devices.reserve(device_mesh.num_elements());
std::vector<int64_t> tmp_indices(device_mesh.num_dimensions(), 0);
std::function<void(int64_t, std::vector<int64_t>)>
generate_tile_assignment_devices;
generate_tile_assignment_devices = [&](int64_t tensor_dim,
std::vector<int64_t> mesh_indices) {
if (tensor_dim == shape.rank() - 1) {
AppendFlattenElements(&tile_assignment_devices, device_mesh, mesh_indices,
-1, tmp_indices);
} else {
int64_t next_tensor_dim = tensor_dim + 1;
int64_t next_mesh_dim = -1;
// Position of next_tensor_dim in tensor_dims, or negative if not split.
int64_t index = GetIndex(tensor_dims, next_tensor_dim);
if (index >= 0) {
next_mesh_dim = mesh_dims[index];
}
for (int64_t i = 0; i < tile_assignment_dimensions[next_tensor_dim];
++i) {
if (next_mesh_dim != -1) {
mesh_indices[next_mesh_dim] = i;
}
generate_tile_assignment_devices(next_tensor_dim, mesh_indices);
}
}
};
// All mesh indices start unset (-1).
std::vector<int64_t> mesh_indices(device_mesh.num_dimensions(), -1);
generate_tile_assignment_devices(-1, mesh_indices);
// Make HloSharding
Array<int64_t> tile_assignment(tile_assignment_dimensions);
tile_assignment.SetValues(tile_assignment_devices);
return replicate_on_last_tile_dim
? HloSharding::PartialTile(std::move(tile_assignment))
: HloSharding::Tile(std::move(tile_assignment));
}
// Compute the resharding cost vector from multiple possible strategies
// to a desired sharding spec.
std::vector<double> ReshardingCostVector(
const StrategyVector* strategies, const Shape& shape,
const HloSharding& required_sharding,
const ClusterEnvironment& cluster_env) {
// Only works with strategy vector
CHECK(!strategies->is_tuple);
std::vector<double> ret;
for (const auto& x : strategies->leaf_vector) {
ret.push_back(cluster_env.ReshardingCost(shape, x.output_sharding,
required_sharding));
}
return ret;
}
// Create the resharding cost vector for a follow strategy.
std::vector<double> FollowInsCostVector(int64_t source_len, int64_t index) {
std::vector<double> ret(source_len, INFINITY_COST);
ret[index] = 0;
return ret;
}
// Factory functions for StrategyVector.
std::unique_ptr<StrategyVector> CreateLeafStrategyVector(
size_t instruction_id, const HloInstruction* ins,
const StrategyMap& strategy_map, LeafStrategies& leaf_strategies) {
std::unique_ptr<StrategyVector> strategies =
absl::make_unique<StrategyVector>();
strategies->is_tuple = false;
strategies->id = leaf_strategies.size();
leaf_strategies.push_back(strategies.get());
strategies->instruction_id = instruction_id;
for (int64_t i = 0; i < ins->operand_count(); ++i) {
strategies->in_nodes.push_back(strategy_map.at(ins->operand(i)).get());
}
return strategies;
}
// Allocates a tuple StrategyVector for instruction `instruction_id`.
// Tuple nodes are not registered in the leaf strategy list, so their id is
// -1; callers (e.g. FollowInsStrategyVector) populate `childs`.
std::unique_ptr<StrategyVector> CreateTupleStrategyVector(
    size_t instruction_id) {
  auto tuple = absl::make_unique<StrategyVector>();
  tuple->instruction_id = instruction_id;
  tuple->is_tuple = true;
  tuple->id = -1;
  return tuple;
}
std::unique_ptr<StrategyVector> FollowInsStrategyVector(
const StrategyVector* src_strategies, const Shape& shape,
size_t instruction_id, bool have_memory_cost,
LeafStrategies& leaf_strategies) {
std::unique_ptr<StrategyVector> strategies;
if (src_strategies->is_tuple) {
CHECK(shape.IsTuple());
CHECK_EQ(shape.tuple_shapes_size(), src_strategies->childs.size());
strategies = CreateTupleStrategyVector(instruction_id);
strategies->childs.reserve(src_strategies->childs.size());
for (size_t i = 0; i < src_strategies->childs.size(); ++i) {
strategies->childs.push_back(FollowInsStrategyVector(
src_strategies->childs[i].get(), shape.tuple_shapes(i),
instruction_id, have_memory_cost, leaf_strategies));
}
} else {
CHECK(shape.IsArray());
strategies = absl::make_unique<StrategyVector>();
strategies->is_tuple = false;
strategies->id = leaf_strategies.size();
leaf_strategies.push_back(strategies.get());
strategies->instruction_id = instruction_id;
strategies->in_nodes.push_back(src_strategies);
strategies->following = src_strategies;
strategies->leaf_vector.reserve(src_strategies->leaf_vector.size());
for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) {
HloSharding output_spec =
src_strategies->leaf_vector[sid].output_sharding;
std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
have_memory_cost ? GetBytes(shape) / output_spec.NumTiles() : 0;
std::vector<std::vector<double>> resharding_costs = {
FollowInsCostVector(src_strategies->leaf_vector.size(), sid)};
strategies->leaf_vector.push_back(
ShardingStrategy({name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{}}));
}
}
return strategies;
}
// Add "Replicate()" strategy
void AddReplicatedStrategy(const HloInstruction* ins,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyVector>& strategies,
double replicated_penalty) {
HloSharding output_spec = HloSharding::Replicate();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
resharding_costs.push_back(ReshardingCostVector(
strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(),
output_spec, cluster_env));
}
strategies->leaf_vector.push_back(
ShardingStrategy({"R",
HloSharding::Replicate(),
replicated_penalty,
0,
GetBytes(ins->shape()),
std::move(resharding_costs),
{}}));
}
// Enumerate all 1d partition strategies.
void EnumerateAll1DPartition(const HloInstruction* ins,
const Array<int64_t>& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyVector>& strategies,
bool only_allow_divisible,
const std::string& suffix) {
// Split one dim
for (int64_t i = 0; i < ins->shape().rank(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
if (only_allow_divisible &&
ins->shape().dimensions(i) % device_mesh.dim(j) != 0) {
continue;
}
std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix;
HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
const HloInstruction* operand = ins->operand(k);
if (operand->shape().rank() == 0) {
resharding_costs.push_back(std::vector<double>(
strategy_map.at(operand).get()->leaf_vector.size(), 0.0));
} else {
resharding_costs.push_back(ReshardingCostVector(
strategy_map.at(operand).get(), ins->operand(k)->shape(),
output_spec, cluster_env));
}
}
strategies->leaf_vector.push_back(
ShardingStrategy({name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{}}));
}
}
}
// Enumerate 2D partition
void EnumerateAll2DPartition(const HloInstruction* ins,
const Array<int64_t>& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyVector>& strategies,
bool only_allow_divisible) {
// Fully tile the buffer to 2-d mesh
for (int64_t i = 0; i < ins->shape().rank(); ++i) {
for (int64_t j = 0; j < ins->shape().rank(); ++j) {
if (i == j) {
continue;
}
if (ins->shape().dimensions(i) < device_mesh.dim(0) ||
ins->shape().dimensions(j) < device_mesh.dim(1)) {
continue;
}
if (only_allow_divisible &&
(ins->shape().dimensions(i) % device_mesh.dim(0) != 0 ||
ins->shape().dimensions(j) % device_mesh.dim(1) != 0)) {
continue;
}
std::string name = absl::StrFormat("S{%d,%d} @ {0,1}", i, j);
HloSharding output_spec = Tile(ins->shape(), {i, j}, {0, 1}, device_mesh);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
const HloInstruction* operand = ins->operand(k);
if (operand->shape().rank() == 0) {
resharding_costs.push_back(std::vector<double>(
strategy_map.at(operand).get()->leaf_vector.size(), 0.0));
} else {
resharding_costs.push_back(
ReshardingCostVector(strategy_map.at(operand).get(),
operand->shape(), output_spec, cluster_env));
}
}
strategies->leaf_vector.push_back(
ShardingStrategy({name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{}}));
}
}
}
// Enumerate all 1d partition strategies.
void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
const Array<int64_t>& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyVector>& strategies,
const std::string& suffix) {
const HloInstruction* operand = ins->operand(0);
for (int64_t i = 0; i < ins->shape().rank(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh);
std::optional<HloSharding> input_spec =
hlo_sharding_util::ReshapeSharding(ins->shape(), operand->shape(),
output_spec);
if (!input_spec.has_value()) { // invalid reshape
continue;
}
std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix;
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs{
ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(),
*input_spec, cluster_env)};
strategies->leaf_vector.push_back(
ShardingStrategy({name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{*input_spec}}));
}
}
}
// Enumerate 2D partition for reshape. Batch dim is always partitioned.
void Enumerate2DPartitionReshape(const HloInstruction* ins,
const Array<int64_t>& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
const InstructionBatchDimMap& batch_dim_map,
std::unique_ptr<StrategyVector>& strategies) {
auto iter = batch_dim_map.find(ins);
if (iter == batch_dim_map.end()) {
return;
}
int batch_dim = iter->second;
const HloInstruction* operand = ins->operand(0);
// Split batch dim + another dim
for (int64_t i = 0; i < ins->shape().rank(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
if (batch_dim == i || 0 == j) {
continue;
}
HloSharding output_spec =
Tile(ins->shape(), {batch_dim, i}, {0, j}, device_mesh);
std::optional<HloSharding> input_spec =
hlo_sharding_util::ReshapeSharding(ins->shape(), operand->shape(),
output_spec);
if (!input_spec.has_value()) { // invalid reshape
continue;
}
std::string name = absl::StrFormat("S%d%d @ {%d,%d}", batch_dim, i, 0, j);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs{
ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(),
*input_spec, cluster_env)};
strategies->leaf_vector.push_back(
ShardingStrategy({name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{*input_spec}}));
}
}
}
// Return the maximum number of tiles among all strategies of an instruction.
int64_t MaxNumTiles(const StrategyMap& strategy_map,
const HloInstruction* ins) {
const StrategyVector* strategies = strategy_map.at(ins).get();
// TODO(lmzheng): optimize with path compression.
while (strategies->following != nullptr) {
strategies = strategies->following;
}
int64_t max_num_tiles = -1;
std::function<void(const StrategyVector*)> visit_all;
visit_all = [&](const StrategyVector* stra) {
if (stra->is_tuple) {
for (const auto& child : stra->childs) {
visit_all(child.get());
}
} else {
for (size_t i = 0; i < stra->leaf_vector.size(); ++i) {
max_num_tiles = std::max(
max_num_tiles, stra->leaf_vector[i].output_sharding.NumTiles());
}
}
};
visit_all(strategies);
return max_num_tiles;
}
// Choose an operand to follow.
// We choose to follow the operand with the highest priority.
// priority(operand) = max(x.output_spec.num_tiles for x in operand.strategies)
//                     + depth(operand) * (0.1 / max_depth)
// Operands in `undefined_set` are not considered for the priority comparison.
//
// If `alias_map` maps `ins` to one of its operands, that alias source is
// always followed (with tie == false), regardless of priorities.
//
// Return `tie == True` if there are two operands with very close priorities
// (within 4 * depth_normalizer) and we cannot decide which one to follow.
// CHECK-fails if no operand could be selected.
std::pair<int64_t, bool> ChooseOperandToFollow(
    const StrategyMap& strategy_map, const InstructionDepthMap& depth_map,
    const AliasMap& alias_map,
    const absl::flat_hash_set<const HloInstruction*>& undefined_set,
    int64_t max_depth, const HloInstruction* ins) {
  int64_t follow_idx = -1;
  bool tie = false;
  double max_priority = -1e20;
  // Depth is only a small tie-breaker on top of the tile count. Guard against
  // max_depth <= 0: 0.1 / 0 is inf, and 0 * inf yields a NaN priority that
  // never compares greater, leaving follow_idx at -1.
  double depth_normalizer = max_depth > 0 ? 0.1 / max_depth : 0.0;
  double range_delta = 4 * depth_normalizer;
  // The alias entry depends only on `ins`, so look it up once.
  auto alias_it = alias_map.find(ins);
  for (int64_t i = 0; i < ins->operand_count(); ++i) {
    const HloInstruction* operand = ins->operand(i);
    if (!undefined_set.count(operand)) {
      double priority = MaxNumTiles(strategy_map, operand) +
                        depth_map.at(operand) * depth_normalizer;
      if (priority > max_priority + range_delta) {
        follow_idx = i;
        tie = false;
        max_priority = priority;
      } else if (priority >= max_priority - range_delta) {
        tie = true;
      }
    }
    // If an alias constraint is set, always follow its alias source.
    // Previously this only stopped the scan, so a higher-priority operand seen
    // earlier could still be followed instead of the alias source.
    if (alias_it != alias_map.end() && alias_it->second == operand) {
      follow_idx = i;
      tie = false;
      break;
    }
  }
  CHECK_GE(follow_idx, 0);
  return std::make_pair(follow_idx, tie);
}
// Return whether an instruction can follow one of its operands when
// more than one operand has the same priority.
//
// Returns false for kCompare and kAnd. This is used to resolve tricky cases
// where an iota and a parameter have the same priority, which happens for
// embedding, onehot or make_attention_mask. It also returns false for any
// instruction with exactly three operands.
// NOTE(review): the 3-operand rule presumably targets ops such as select or
// clamp — confirm the intent.
bool AllowTieFollowing(const HloInstruction* ins) {
  switch (ins->opcode()) {
    case HloOpcode::kCompare:
    case HloOpcode::kAnd:
      return false;
    default:
      break;
  }
  return ins->operand_count() != 3;
}
// Build possible sharding strategies and their costs for all instructions.
StatusOr<std::tuple<StrategyMap, LeafStrategies, AssociativeDotPairs>>
BuildStrategyAndCost(const HloInstructionSequence& sequence,
const InstructionDepthMap& depth_map,
const InstructionBatchDimMap& batch_dim_map,
const AliasMap& alias_map,
const ClusterEnvironment& cluster_env,
AutoShardingSolverOption& solver_option) {
const Array<int64_t>& device_mesh = cluster_env.device_mesh;
const Array<int64_t>& device_mesh_1d = cluster_env.device_mesh_1d;
StrategyMap strategy_map;
LeafStrategies leaf_strategies;
AssociativeDotPairs associative_dot_pairs;
absl::flat_hash_set<const HloInstruction*> undefined_set;
const std::vector<HloInstruction*>& instructions = sequence.instructions();
// Count the non-one mesh dimension.
int mesh_nn_dims = 0;
for (int dim : device_mesh.dimensions()) {
if (dim > 1) {
mesh_nn_dims++;
}
}
// Gather all output values
absl::flat_hash_set<const HloInstruction*> output_set;
for (size_t i = 0; i < instructions.back()->operand_count(); ++i) {
output_set.insert(instructions.back()->operand(i));
}
// Add penalty for replicated tensors
double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) +
cluster_env.AllReduceCost(1, 1));
int64_t max_depth = -1;
for (auto iter : depth_map) {
max_depth = std::max(max_depth, iter.second);
}
int64_t disallowed_follow = 0;
// Register strategies and their costs for each instruction.
for (size_t instruction_id = 0; instruction_id < instructions.size();
++instruction_id) {
const HloInstruction* ins = instructions[instruction_id];
std::unique_ptr<StrategyVector> strategies;
HloOpcode opcode = ins->opcode();
switch (opcode) {
case HloOpcode::kParameter:
case HloOpcode::kRng: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
// Split 1 dim
EnumerateAll1DPartition(ins, device_mesh, cluster_env, strategy_map,
strategies, true, "");
// Split 2 dims
if (cluster_env.non_zero_mesh_dims.size() > 1) {
// Add penalty for 1d partial tiled layout
for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) {
strategies->leaf_vector[i].compute_cost += replicated_penalty * 0.8;
}
if (batch_dim_map.count(ins)) {
// This is a pruning heuristic: only allow 2d partition
// for parameters with a batch dim. These parameters are
// typically input data and intermediate activations.
EnumerateAll2DPartition(ins, device_mesh, cluster_env, strategy_map,
strategies, true);
}
if (solver_option.allow_mixed_mesh_shape) {
// Split 1 dim, but for 1d mesh
EnumerateAll1DPartition(ins, device_mesh_1d, cluster_env,
strategy_map, strategies, true, " 1d");
}
}
if (solver_option.allow_replicated_parameters) {
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies,
replicated_penalty);
}
RemoveDuplicatedStrategy(strategies);
// If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
// and only keep the data parallel strategies.
if (solver_option.force_batch_dim_to_mesh_dim >= 0 &&
batch_dim_map.count(ins)) {
TF_RETURN_IF_ERROR(FilterStrategy(ins, strategies, cluster_env,
batch_dim_map, solver_option));
}
break;
}
case HloOpcode::kConstant: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies, 0);
break;
}
case HloOpcode::kBroadcast: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
const HloInstruction* operand = ins->operand(0);
if (undefined_set.count(operand)) {
break;
}
// Create follow strategies
const StrategyVector* src_strategies = strategy_map.at(operand).get();
CHECK(!src_strategies->is_tuple);
strategies->following = src_strategies;
for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) {
HloSharding output_spec = BroadcastSharding(
src_strategies->leaf_vector[sid].output_sharding, ins->shape(),
ins->dimensions());
std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
strategies->leaf_vector.push_back(ShardingStrategy(
{name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
{FollowInsCostVector(src_strategies->leaf_vector.size(), sid)},
{}}));
}
// If the operand is a scalar, following it only generates "Replicated"
// strategy. So we should register new strategies instead of following
// it.
if (operand->shape().rank() == 0) {
if (!output_set.count(ins) &&
operand->opcode() == HloOpcode::kConstant) {
// one execption: always replicate intermidiate broadcasted
// constants.
break;
}
strategies->following = nullptr;
strategies->leaf_vector.clear();
// Split one dim
for (int64_t i = 0; i < ins->shape().rank(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
std::string name = absl::StrFormat("S%d @ %d", i, j);
HloSharding output_spec =
Tile(ins->shape(), {i}, {j}, device_mesh);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
GetBytes(ins->shape()) / output_spec.NumTiles();
strategies->leaf_vector.push_back(ShardingStrategy(
{name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
{std::vector<double>(src_strategies->leaf_vector.size(),
0.0)},
{}}));
}
}
// Replicate
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies,
replicated_penalty);
RemoveDuplicatedStrategy(strategies);
}
break;
}
case HloOpcode::kReshape: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
const HloInstruction* operand = ins->operand(0);
// Create follow strategies
if (!undefined_set.count(operand) &&
((ins->users().size() == 1 && !IsBatchDimSwitchReshape(ins)) ||
(mesh_nn_dims >= 2 && !solver_option.allow_mixed_mesh_shape))) {
const StrategyVector* src_strategies = strategy_map.at(operand).get();
CHECK(!src_strategies->is_tuple);
strategies->following = src_strategies;
for (int64_t sid = 0; sid < src_strategies->leaf_vector.size();
++sid) {
std::optional<HloSharding> output_spec =
hlo_sharding_util::ReshapeSharding(
operand->shape(), ins->shape(),
src_strategies->leaf_vector[sid].output_sharding);
if (!output_spec.has_value()) {
continue;
}
if (!IsValidTileAssignment(*output_spec)) {
continue;
}
std::string name = ToStringSimple(*output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
GetBytes(ins->shape()) / output_spec->NumTiles();
strategies->leaf_vector.push_back(ShardingStrategy(
{name,
*output_spec,
compute_cost,
communication_cost,
memory_cost,
{FollowInsCostVector(src_strategies->leaf_vector.size(), sid)},
{}}));
}
}
// Fail to create follow strategies, enumerate all possible cases
if (strategies->leaf_vector.empty()) {
strategies->leaf_vector.clear();
strategies->following = nullptr;
// Split 1 dim
EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env,
strategy_map, strategies, "");
if (solver_option.allow_mixed_mesh_shape &&
cluster_env.non_zero_mesh_dims.size() > 1) {
// Split 1 dim, but for 1d mesh
EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env,
strategy_map, strategies, " 1d");
// Split 2 dim, one is always the batch dim
Enumerate2DPartitionReshape(ins, device_mesh, cluster_env,
strategy_map, batch_dim_map,
strategies);
}
// Replicate
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies,
replicated_penalty);
RemoveDuplicatedStrategy(strategies);
}
break;
}
case HloOpcode::kTranspose:
case HloOpcode::kReverse: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
const HloInstruction* operand = ins->operand(0);
if (undefined_set.count(operand)) {
break;
}
// Create follow strategies
const StrategyVector* src_strategies = strategy_map.at(operand).get();
CHECK(!src_strategies->is_tuple);
strategies->following = src_strategies;
for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) {
HloSharding output_spec = Undefined();
if (opcode == HloOpcode::kTranspose) {
output_spec = hlo_sharding_util::TransposeSharding(
src_strategies->leaf_vector[sid].output_sharding,
ins->dimensions());
} else {
output_spec = hlo_sharding_util::ReverseSharding(
src_strategies->leaf_vector[sid].output_sharding,
ins->dimensions());
}
std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles();
strategies->leaf_vector.push_back(ShardingStrategy(
{name,
output_spec,
compute_cost,
communication_cost,
memory_cost,
{FollowInsCostVector(src_strategies->leaf_vector.size(), sid)},
{}}));
}
break;
}
case HloOpcode::kPad:
case HloOpcode::kSlice:
case HloOpcode::kConcatenate: // TODO(lmzheng): revisit concatenate
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
// Choose an operand to follow
int64_t follow_idx;
bool tie;
std::tie(follow_idx, tie) = ChooseOperandToFollow(
strategy_map, depth_map, alias_map, undefined_set, max_depth, ins);
// Create follow strategies
const HloInstruction* operand = ins->operand(follow_idx);
const StrategyVector* src_strategies = strategy_map.at(operand).get();
CHECK(!src_strategies->is_tuple);
strategies->following = src_strategies;
for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) {
std::optional<HloSharding> output_spec;
switch (opcode) {
case HloOpcode::kPad:
case HloOpcode::kSlice:
case HloOpcode::kConcatenate:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
output_spec = PropagateDimwiseSharding(
src_strategies->leaf_vector[sid].output_sharding,
operand->shape(), ins->shape());
break;
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
output_spec = PropagateReduceWindowSharding(
src_strategies->leaf_vector[sid].output_sharding,
operand->shape(), ins->window());
break;
default:
LOG(FATAL) << "Unhandled instruction: " + ins->ToString();
}
if (!output_spec.has_value()) {
continue;
}
std::string name = ToStringSimple(*output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
if (k == follow_idx) {
resharding_costs.push_back(
FollowInsCostVector(src_strategies->leaf_vector.size(), sid));
} else {
operand = ins->operand(k);
if (operand->shape().rank() > 0) {
resharding_costs.push_back(ReshardingCostVector(
strategy_map.at(operand).get(), operand->shape(),
*output_spec, cluster_env));
} else {
resharding_costs.push_back(std::vector<double>(
strategy_map.at(operand)->leaf_vector.size(), 0.0));
}
}
}
strategies->leaf_vector.push_back(
ShardingStrategy({name,
*output_spec,
compute_cost,
communication_cost,
memory_cost,
std::move(resharding_costs),
{}}));
}
break;
}
case HloOpcode::kGather: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
auto dnums = ins->gather_dimension_numbers();
// Split one update_window_dims
for (size_t i = 0; i < dnums.offset_dims().size(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh);
int operand_dim = dnums.offset_dims(i);
CHECK_LT(operand_dim, ins->operand(0)->shape().rank())
<< "Does not support this kind of Gather.";
CHECK_EQ(ins->shape().dimensions(operand_dim),
ins->operand(0)->shape().dimensions(operand_dim))
<< "Does not support this kind of Gather.";
std::vector<HloSharding> operand_specs{
Tile(ins->operand(0)->shape(), {operand_dim}, {j}, device_mesh),
HloSharding::Replicate(),
};
std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
resharding_costs.push_back(ReshardingCostVector(
strategy_map.at(ins->operand(k)).get(),
ins->operand(k)->shape(), operand_specs[k], cluster_env));
}
strategies->leaf_vector.push_back(ShardingStrategy(
{name, output_spec, compute_cost, communication_cost,
memory_cost, std::move(resharding_costs),
std::move(operand_specs)}));
}
}
// Replicate all
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies, 0);
break;
}
case HloOpcode::kScatter: {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
auto dnums = ins->scatter_dimension_numbers();
// Split one update_window_dims
for (size_t i = 0; i < dnums.update_window_dims().size(); ++i) {
for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) {
if (device_mesh.dim(j) == 1 ||
ins->shape().dimensions(i) < device_mesh.dim(j)) {
continue;
}
HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh);
int operand_dim = dnums.update_window_dims(i);
int update_dim = operand_dim;
CHECK_EQ(ins->shape().dimensions(operand_dim),
ins->operand(0)->shape().dimensions(operand_dim));
CHECK_EQ(ins->shape().dimensions(operand_dim),
ins->operand(2)->shape().dimensions(update_dim));
std::vector<HloSharding> operand_specs{
Tile(ins->operand(0)->shape(), {operand_dim}, {j}, device_mesh),
HloSharding::Replicate(),
Tile(ins->operand(2)->shape(), {update_dim}, {j}, device_mesh),
};
std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
GetBytes(ins->shape()) / output_spec.NumTiles();
std::vector<std::vector<double>> resharding_costs;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
resharding_costs.push_back(ReshardingCostVector(
strategy_map.at(ins->operand(k)).get(),
ins->operand(k)->shape(), operand_specs[k], cluster_env));
}
strategies->leaf_vector.push_back(ShardingStrategy(
{name, output_spec, compute_cost, communication_cost,
memory_cost, std::move(resharding_costs),
std::move(operand_specs)}));
}
}
// Replicate all
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies, 0);
break;
}
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kConvert:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
// Binary elementwise operations
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kCompare:
case HloOpcode::kComplex: