[Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)

* [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul

Prior to this commit, if the weight of a matmul had a dynamic shape and that
matmul was being combined by the `CombineParallelMatmul` pass, the pass could
segfault when `dim.as<IntImmNode>()` returned a null pointer (see the sketch
after this commit message).

This commit adds explicit test cases for these dynamic shapes, and
updates `CombineParallelMatmul` to handle them.

* Add Tuple constructor for PR-16589
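A minimal sketch of the failure mode described in the first bullet, using TVM's `PrimExpr` / `IntImmNode` API; the helper name and variables below are illustrative and not taken from the diff:

#include <tvm/ir/expr.h>

#include <cstdint>

// `dim` stands in for one entry of a weight tensor's shape.  For a dynamic
// dimension (e.g. a tir::Var), dim.as<IntImmNode>() returns nullptr, and
// dereferencing that pointer without a check is the segfault fixed here.
int64_t StaticWidthOr(const tvm::PrimExpr& dim, int64_t fallback) {
  if (const auto* as_int = dim.as<tvm::IntImmNode>()) {
    return as_int->value;
  }
  return fallback;
}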
Lunderberg committed Feb 23, 2024
1 parent 864fd5c commit 89cc09c
Showing 3 changed files with 240 additions and 61 deletions.
18 changes: 18 additions & 0 deletions include/tvm/relax/expr.h
@@ -320,6 +320,24 @@ class Tuple : public Expr {
*/
TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());

/*!
* \brief Utility constructor to handle conversion to relax::Expr
*
* If the calling scope already has an array of a specific type of
* relax expression (e.g. `Array<relax::Var>`), it must be converted
* into an array of base type. This constructor handles the
* conversion to the base `Array<relax::Expr>`.
*
* \tparam RelaxExpr The type of relax expression passed in as an argument.
*
* \param fields The fields of a tuple.
*
* \param span The source span of the expression.
*/
template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
: Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
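A hedged usage sketch of the templated constructor added above (the helper name is illustrative and not part of the diff): an `Array` of a derived expression type such as `relax::Var` can now be passed to `Tuple` directly, rather than being copied element-by-element into an `Array<Expr>` first.

#include <tvm/relax/expr.h>

tvm::relax::Tuple TupleFromVars(const tvm::Array<tvm::relax::Var>& vars) {
  // Before this constructor, a call site had to build the base array by hand:
  //   tvm::Array<tvm::relax::Expr> fields;
  //   for (const auto& var : vars) fields.push_back(var);
  //   return tvm::relax::Tuple(fields);
  return tvm::relax::Tuple(vars);
}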
160 changes: 100 additions & 60 deletions src/relax/transform/combine_parallel_matmul.cc
@@ -71,7 +71,16 @@ struct Patterns {
WildcardPattern input;
std::vector<WildcardPattern> rhs;
std::vector<WildcardPattern> bias;
std::vector<CallPattern> matmul, bias_add, activation;
std::vector<CallPattern> matmul;
std::vector<CallPattern> bias_add;
std::vector<CallPattern> activation;
};

struct SplitInfo {
Var rhs;
Optional<Var> bias;
PrimExpr split_size;
DFPattern pattern_to_replace;
};

Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;

auto inp = matchings[patterns.input];
auto lhs = matchings[patterns.input];

const auto& patterns_to_replace = [&patterns, &branch_info]() {
if (branch_info.activation) return patterns.activation;
if (branch_info.bias_dim) return patterns.bias_add;
return patterns.matmul;
}();

Array<Var> rhs, bias;
for (auto ind : indices) {
rhs.push_back(matchings[patterns.rhs[ind]]);
if (branch_info.bias_dim) {
ICHECK(matchings.count(patterns.bias[ind]));
bias.push_back(matchings[patterns.bias[ind]]);
std::vector<SplitInfo> splits;
for (auto index : indices) {
Var rhs = matchings[patterns.rhs[index]];
Optional<Var> bias = NullOpt;
if (branch_info.bias_dim.has_value()) {
bias = matchings[patterns.bias[index]];
}
PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
DFPattern pattern_to_replace = patterns_to_replace[index];
splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
}
// At most one dynamic output shape can be part of the combined
// matmul, and it must be the last item in the split. Use
// `std::stable_sort` instead of `std::sort` to maintain a
// consistent order for all static shapes, and to consistently
// select the same dynamic weight to participate.
auto is_dynamic_split = [](const SplitInfo& split) -> bool {
return !split.split_size->IsInstance<IntImmNode>();
};
std::stable_sort(splits.begin(), splits.end(),
[&is_dynamic_split](const auto& a, const auto& b) {
return is_dynamic_split(a) < is_dynamic_split(b);
});
// Remove anything after the first dynamic shape participating
// in the combined matmul.
if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
it != splits.end()) {
splits.erase(it + 1, splits.end());
}

if (!check(inp, rhs, bias, bindings)) {
if (splits.size() == 1) {
continue;
}

auto make_tuple = [](const Array<Var>& var_array) {
Array<Expr> exp_array;
for (auto v : var_array) exp_array.push_back(v);
return Tuple(exp_array);
};
Array<Var> rhs;
Array<Var> bias;
for (const auto& split : splits) {
rhs.push_back(split.rhs);
if (split.bias) {
bias.push_back(split.bias.value());
}
}

auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
if (!check(lhs, rhs, bias, bindings)) {
continue;
}

const auto& pattern_to_replace = [&patterns, &branch_info]() {
if (branch_info.activation) return patterns.activation;
if (branch_info.bias_dim) return patterns.bias_add;
return patterns.matmul;
}();
auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);

if (branch_info.bias_dim) {
auto bias_dim = GetTensorSInfo(bias[0])->ndim;
auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
matmul_combined = add(matmul_combined, concat_bias);
}

@@ -191,20 +228,23 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
}
}

int ind = 0;
int split_index = 0;
Array<IntImm> sections;
for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
ind += width->value;
sections.push_back(IntImm(DataType::Int(64), ind));
for (size_t i = 0; i + 1 < splits.size(); i++) {
auto width = splits[i].split_size.as<IntImmNode>();
ICHECK(width) << "InternalError: "
<< "All splits except the last one must have a static shape";
split_index += width->value;
sections.push_back(IntImm(DataType::Int(64), split_index));
}

int lhs_dim = GetTensorSInfo(inp)->ndim;
int lhs_dim = GetTensorSInfo(lhs)->ndim;
int split_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
auto chunks = split(matmul_combined, sections, split_axis);

for (size_t i = 0; i < indices.size(); ++i) {
auto bound_var = matchings[pattern_to_replace[indices[i]]];
for (size_t i = 0; i < splits.size(); i++) {
const auto& split = splits[i];
auto bound_var = matchings[split.pattern_to_replace];
replacements.Set(bound_var, TupleGetItem(chunks, i));
}
}
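The reordering and split-point logic above can be summarized by the following standalone, hedged sketch (function and variable names are illustrative, not from the diff): static-width weights come first, at most one dynamic-width weight is kept last, and the split indices are the running sums of every width except the last.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Width of each weight along the concatenated axis; nullopt marks a dynamic
// width.  Returns the static indices at which the combined output is split.
std::vector<int64_t> SplitSections(std::vector<std::optional<int64_t>> widths) {
  auto is_dynamic = [](const std::optional<int64_t>& w) { return !w.has_value(); };
  // Keep static widths first and dynamic widths last, preserving relative order.
  std::stable_sort(widths.begin(), widths.end(),
                   [&](const auto& a, const auto& b) { return is_dynamic(a) < is_dynamic(b); });
  // Drop everything after the first dynamic width, so at most one remains.
  if (auto it = std::find_if(widths.begin(), widths.end(), is_dynamic); it != widths.end()) {
    widths.erase(it + 1, widths.end());
  }
  // The split points are cumulative sums of all widths except the last one.
  std::vector<int64_t> sections;
  int64_t offset = 0;
  for (size_t i = 0; i + 1 < widths.size(); ++i) {
    offset += widths[i].value();
    sections.push_back(offset);
  }
  return sections;
}

For example, widths {dynamic, 640} reorder to {640, dynamic} and yield the single split index 640, matching the `indices_or_sections=[640]` in the tests below.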
@@ -244,43 +284,43 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {

PostOrderVisit(f, [&](const Expr& e) {
if (!e->IsInstance<CallNode>()) return;
if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);

auto it = groups.find(matmul_lhs.get());
BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
std::optional<int> bias_dim = std::nullopt;
std::optional<std::string> activation = std::nullopt;
auto match = ExtractMatchedExpr(pat, e, bindings);
if (!match) return;

if (match.value().count(bias_pat)) {
bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
}
auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);

for (size_t i = 0; i < activations.size(); ++i) {
if (match.value().count(activation_pat[i]) ||
match.value().count(bias_activation_pat[i])) {
activation = activations[i];
}
std::optional<int> bias_dim = std::nullopt;
std::optional<std::string> activation = std::nullopt;

if (match.value().count(bias_pat)) {
bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
}

for (size_t i = 0; i < activations.size(); ++i) {
if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
activation = activations[i];
}
}

if (!branch) {
// Create a new subgraph with one matmul
groups[matmul_lhs.get()] = {1, bias_dim, activation};
} else {
// Create a new branch in the existing parallel matmul subtree, and
// invalidate bias and activation information when needed.
branch->num_branches += 1;
if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
// Create a new branch in the existing parallel matmul subtree, and
// invalidate bias and activation information when needed.
BranchInfo* branch = &it->second;

branch->num_branches += 1;

if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
branch->bias_dim = std::nullopt;
}
if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
branch->bias_dim = std::nullopt;
}

if (!activation || (branch->activation && *branch->activation != *activation)) {
branch->activation = std::nullopt;
}
if (!activation || (branch->activation && *branch->activation != *activation)) {
branch->activation = std::nullopt;
}
return;
} else {
// Create a new subgraph with one matmul
groups[matmul_lhs.get()] = {1, bias_dim, activation};
}
});

123 changes: 122 additions & 1 deletion tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,7 +525,16 @@ def expected(
tvm.ir.assert_structural_equal(after, expected)


def test_dynamic_rhs():
def test_combine_matmul_of_static_and_dynamic_shapes():
"""Combine two matmuls, one with dynamic shape
The `R.split` operator must have a static list of integer indices
at which to split the matmul output, because these integer indices
are stored as operator attributes. However, the last output can
still have a dynamic shape.
"""

@R.function(private=True)
def before(
x: R.Tensor((2, 1024, 640), "float32"),
@@ -572,5 +581,117 @@ def expected(
tvm.ir.assert_structural_equal(after, expected)


def test_combine_matmul_of_dynamic_and_static_shapes():
"""Combine two matmuls, one with dynamic shape
Like `test_combine_matmul_of_static_and_dynamic_shapes`, but the
dynamic-shaped matmul is encountered first. Due to the
requirements imposed by `R.split` storing the split indices as
static integers, the static-shaped weights must occur first in the
concatenated weights.
"""

@R.function(private=True)
def before(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, "M"), "float32"),
w1: R.Tensor((640, 640), "float32"),
):
M = T.int64()
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
out = (lv0, lv1)
R.output(out)
return out

@R.function(private=True)
def expected(
x: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, "M"), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
) -> R.Tuple(
R.Tensor((2, 1024, "M"), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")
):
M = T.int64()
with R.dataflow():
lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w1, w0), axis=1)
lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
x, lv, out_dtype="float32"
)
lv2: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, M), dtype="float32"),
) = R.split(lv1, indices_or_sections=[640], axis=2)
lv0: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
out: R.Tuple(
R.Tensor((2, 1024, M), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = (lv0, lv1_1)
R.output(out)
return out

after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]

tvm.ir.assert_structural_equal(after, expected)


def test_limit_one_dynamic_shape_in_combined_matmul():
"""Combine two matmuls, one with dynamic shape
Like `test_combine_matmul_of_static_and_dynamic_shapes`, but with
two dynamic weights that could, in principle, be merged together.
Because `R.split` must have integer indices at which to split,
only one of the dynamic outputs can be part of the combined
matmul.
"""

@R.function(private=True)
def before(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, "M"), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, "N"), "float32"),
):
M = T.int64()
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out

@R.function(private=True)
def expected(
x: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, "M"), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((640, "N"), "float32"),
) -> R.Tuple(
R.Tensor((2, 1024, "M"), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, "N"), dtype="float32"),
):
M = T.int64()
with R.dataflow():
concat_weights = R.concat((w1, w0), axis=1)
concat_output = R.matmul(x, concat_weights, out_dtype="float32")
split_output: R.Tuple(
[R.Tensor([2, 1024, 640], dtype="float32"), R.Tensor([2, 1024, M], dtype="float32")]
) = R.split(concat_output, indices_or_sections=[640], axis=2)
lv0 = split_output[1]
lv1 = split_output[0]
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out

after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]

tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
tvm.testing.main()
