InsertReshardingsPass decomposes matmul/linear+allreduce. #4133

@wujingyue

For example, it should decompose the linear in

TensorView* in = makeContigConcreteTensor({-1, -1, h_i});  // [i{b}, i{s}, i{h_i}]
TensorView* weight = makeContigConcreteTensor({h_o, h_i});  // [i{h_o}, i{h_i}]
TensorView* out = linear(in, weight, /*bias=*/nullptr);  // [i{b}, i{s}, i{h_o}, r{h_i}]

for (auto* tv : {in, weight, out}) {
  tv->setDeviceMesh(mesh);
}
in->outer_split(-1, d);  // split h_i by num of devices
in->axis(-2)->parallelize(ParallelType::DIDx);
weight->outer_split(-1, d);
weight->axis(-2)->parallelize(ParallelType::DIDx);
out->outer_split(1, d);  // split s by num of devices, i.e., sequence parallel
out->axis(1)->parallelize(ParallelType::DIDx);

into a local linear followed by a ReduceScatter. It's going to be similar to this example, but not exactly the same.
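To make the intended equivalence concrete, here is a minimal host-side sketch in plain C++ that does not use nvFuser or any communication library. It assumes the batch dimension is folded into `s`, simulates `d` devices with a loop, and checks nothing about how the pass itself should be implemented. The names `Matrix`, `linearOverRange`, `reduceScatterRows`, and `shardedLinear` are hypothetical helpers introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix with `rows` x `cols` elements, zero-initialized.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  double at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// out[s, o] = sum over k in [k_begin, k_end) of in[s, k] * weight[o, k].
// The full range gives the unsharded linear; a sub-range gives the local
// linear that one device computes on its shard of h_i.
Matrix linearOverRange(const Matrix& in, const Matrix& weight,
                       std::size_t k_begin, std::size_t k_end) {
  assert(in.cols == weight.cols && k_begin <= k_end && k_end <= in.cols);
  Matrix out(in.rows, weight.rows);
  for (std::size_t s = 0; s < in.rows; ++s) {
    for (std::size_t o = 0; o < weight.rows; ++o) {
      double acc = 0.0;
      for (std::size_t k = k_begin; k < k_end; ++k) {
        acc += in.at(s, k) * weight.at(o, k);
      }
      out.at(s, o) = acc;
    }
  }
  return out;
}

// Simulated ReduceScatter along the row (s) dimension: device r ends up with
// the elementwise sum, over all devices' partial outputs, of row block r.
std::vector<Matrix> reduceScatterRows(const std::vector<Matrix>& partials) {
  const std::size_t d = partials.size();
  assert(d > 0 && partials[0].rows % d == 0);
  const std::size_t block = partials[0].rows / d;
  std::vector<Matrix> shards;
  for (std::size_t r = 0; r < d; ++r) {
    Matrix shard(block, partials[0].cols);
    for (const Matrix& p : partials) {
      for (std::size_t i = 0; i < block; ++i) {
        for (std::size_t j = 0; j < p.cols; ++j) {
          shard.at(i, j) += p.at(r * block + i, j);
        }
      }
    }
    shards.push_back(shard);
  }
  return shards;
}

// Local linear on each device's h_i shard, followed by the simulated
// ReduceScatter. Returns one [s/d, h_o] output shard per device.
std::vector<Matrix> shardedLinear(const Matrix& in, const Matrix& weight,
                                  std::size_t d) {
  assert(d > 0 && in.cols % d == 0);
  const std::size_t k_shard = in.cols / d;
  std::vector<Matrix> partials;
  for (std::size_t dev = 0; dev < d; ++dev) {
    partials.push_back(
        linearOverRange(in, weight, dev * k_shard, (dev + 1) * k_shard));
  }
  return reduceScatterRows(partials);
}
```

Under these assumptions, stacking the shards returned by `shardedLinear(in, weight, d)` in device order should reproduce `linearOverRange(in, weight, 0, h_i)`, which is the sequence-parallel output the snippet above asks for.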

The above is one layout. See/run

to collect more use cases.
