For example, it should decompose the linear in
TensorView* in = makeContigConcreteTensor({-1, -1, h_i}); // [i{b}, i{s}, i{h_i}]
TensorView* weight = makeContigConcreteTensor({h_o, h_i}); // [i{h_o}, i{h_i}]
TensorView* out = linear(in, weight, /*bias=*/nullptr); // [i{b}, i{s}, i{h_o}, r{h_i}]
for (auto* tv : {in, weight, out}) {
tv->setDeviceMesh(mesh);
}
in->outer_split(-1, d); // split h_i by num of devices
in->axis(-2)->parallelize(ParallelType::DIDx);
weight->outer_split(-1, d);
weight->axis(-2)->parallelize(ParallelType::DIDx);
out->outer_split(1, d); // split s by num of devices, i.e., sequence parallel
out->axis(1)->parallelize(ParallelType::DIDx);

into a local linear followed by a ReduceScatter. It's going to be similar to this example, but not exactly.
The above is one layout. See/run
to collect more use cases.
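
To make the intended decomposition concrete, here is a minimal host-side sketch in plain C++ that does not use nvFuser at all. It simulates `d` devices that each hold one slice of `h_i`, run a local linear to get a partial `[s, h_o]` result, and then ReduceScatter along `s`. Every name in it (`Matrix`, `linearReference`, `localLinear`, `reduceScatter`, `decompositionMatchesReference`) is hypothetical and only for illustration. It also assumes the batch dimension is dropped and that `s` and `h_i` are divisible by `d`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix helper; not an nvFuser type.
struct Matrix {
  std::size_t rows, cols;
  std::vector<long> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0) {}
  long& at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  long at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Unsharded reference: out[s, h_o] = sum_k in[s, k] * weight[h_o, k].
Matrix linearReference(const Matrix& in, const Matrix& weight) {
  Matrix out(in.rows, weight.rows);
  for (std::size_t i = 0; i < in.rows; ++i)
    for (std::size_t j = 0; j < weight.rows; ++j)
      for (std::size_t k = 0; k < in.cols; ++k)
        out.at(i, j) += in.at(i, k) * weight.at(j, k);
  return out;
}

// Local linear on device `dev`: reduces only over that device's h_i slice,
// so the result is a full-shape [s, h_o] partial sum.
Matrix localLinear(const Matrix& in, const Matrix& weight,
                   std::size_t dev, std::size_t d) {
  const std::size_t chunk = in.cols / d;  // assumes h_i % d == 0
  Matrix partial(in.rows, weight.rows);
  for (std::size_t i = 0; i < in.rows; ++i)
    for (std::size_t j = 0; j < weight.rows; ++j)
      for (std::size_t k = dev * chunk; k < (dev + 1) * chunk; ++k)
        partial.at(i, j) += in.at(i, k) * weight.at(j, k);
  return partial;
}

// ReduceScatter along s: sum the partials across devices, then device `dev`
// keeps only rows [dev * s/d, (dev + 1) * s/d).
std::vector<Matrix> reduceScatter(const std::vector<Matrix>& partials) {
  const std::size_t d = partials.size();
  const std::size_t rowsPerDev = partials[0].rows / d;  // assumes s % d == 0
  const std::size_t cols = partials[0].cols;
  std::vector<Matrix> shards;
  for (std::size_t dev = 0; dev < d; ++dev) {
    Matrix shard(rowsPerDev, cols);
    for (std::size_t src = 0; src < d; ++src)
      for (std::size_t i = 0; i < rowsPerDev; ++i)
        for (std::size_t j = 0; j < cols; ++j)
          shard.at(i, j) += partials[src].at(dev * rowsPerDev + i, j);
    shards.push_back(shard);
  }
  return shards;
}

// Runs local linear + ReduceScatter and compares each device's shard with
// the matching rows of the unsharded reference.
bool decompositionMatchesReference(std::size_t s, std::size_t h_i,
                                   std::size_t h_o, std::size_t d) {
  Matrix in(s, h_i), weight(h_o, h_i);
  for (std::size_t i = 0; i < s; ++i)
    for (std::size_t k = 0; k < h_i; ++k)
      in.at(i, k) = static_cast<long>(i * h_i + k + 1);
  for (std::size_t j = 0; j < h_o; ++j)
    for (std::size_t k = 0; k < h_i; ++k)
      weight.at(j, k) = static_cast<long>(2 * j + k) - 3;

  const Matrix expected = linearReference(in, weight);
  std::vector<Matrix> partials;
  for (std::size_t dev = 0; dev < d; ++dev)
    partials.push_back(localLinear(in, weight, dev, d));
  const std::vector<Matrix> shards = reduceScatter(partials);

  const std::size_t rowsPerDev = s / d;
  for (std::size_t dev = 0; dev < d; ++dev)
    for (std::size_t i = 0; i < rowsPerDev; ++i)
      for (std::size_t j = 0; j < h_o; ++j)
        if (shards[dev].at(i, j) != expected.at(dev * rowsPerDev + i, j))
          return false;
  return true;
}
```

Calling `decompositionMatchesReference(4, 6, 3, 2)` compares the two-device pipeline against the unsharded linear. The comparison uses integer arithmetic so it is exact; a floating-point version would need a tolerance because the summation order differs.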