Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Transform] Canonicalize let var = R.const bindings #16601

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct CanonicalizationPlan {
Map<Id, Var> replace_usage;
Map<Id, Var> replace_binding;
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
Map<Id, Constant> inline_constant;
};

/*! \brief Utility class to identify usage location
Expand Down Expand Up @@ -69,6 +70,10 @@ class CanonicalizePlanner : public ExprVisitor {
}
}

for (const auto& [var, constant] : visitor.known_bound_to_constant_) {
plan.inline_constant.Set(var->vid, constant);
}

for (const auto& binding_iter : visitor.trivial_bindings_) {
Var bound_var = binding_iter.first;
Var bound_to = binding_iter.second;
Expand Down Expand Up @@ -180,6 +185,10 @@ class CanonicalizePlanner : public ExprVisitor {
trivial_bindings_.Set(binding->var, parent.value());
}

if (auto constant = value.as<Constant>()) {
known_bound_to_constant_.Set(binding->var, constant.value());
}

known_bindings_.Set(binding->var, value);
def_blocks_.Set(binding->var, current_block_.value());

Expand Down Expand Up @@ -213,6 +222,7 @@ class CanonicalizePlanner : public ExprVisitor {

Map<Var, Var> trivial_bindings_;
Map<Var, Expr> known_bindings_;
Map<Var, Constant> known_bound_to_constant_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
// Set of vars either used outside a dataflow block altogether or outside their
// home dataflow block (the one where they were defined)
Expand Down Expand Up @@ -251,6 +261,9 @@ class BindingCanonicalizer : public ExprMutator {
while (auto opt = plan_.replace_usage.Get(new_var->vid)) {
new_var = opt.value();
}
if (auto opt = plan_.inline_constant.Get(new_var->vid)) {
return VisitExpr(opt.value());
}

return ExprMutator::VisitExpr_(new_var.get());
}
Expand Down
67 changes: 36 additions & 31 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,33 +289,6 @@ def main() -> R.Tensor((), "int32"):
verify(Input, Expected)


def test_unable_to_fold():
@I.ir_module
class MultipleUse:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
# multiple uses -> cannot coalesce
m = R.add(n, n)
R.output(n)
return n

@I.ir_module
class ComplexExpr:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
y = R.const(1)
# y does not appear by itself -> cannot coalesce
n = R.add(y, y)
R.output(n)
return n

verify(MultipleUse, MultipleUse)
verify(ComplexExpr, ComplexExpr)


def test_multiple_outputs():
@I.ir_module
class Input:
Expand Down Expand Up @@ -380,18 +353,26 @@ def main():
verify(Input, Expected)


def test_multiply_used_in_outputs():
# cannot fold output in this case
def test_fold_const_to_output():
@I.ir_module
class UsedInMultipleOutputs:
class Before:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return n

verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
@I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return R.const(1)

verify(Before, Expected)


def test_canonicalize_var_to_dataflow_var_if_legal():
Expand Down Expand Up @@ -972,5 +953,29 @@ def transform_params(
assert_structural_equal(Expected, after)


def test_unwrap_tuple_of_constant():
@I.ir_module
class TestChainAssignments:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(x, y)
return z

@I.ir_module
class Expected:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(R.const(0, "int64"), R.const(1, "int64"))
return z

verify(TestChainAssignments, Expected)


if __name__ == "__main__":
tvm.testing.main()
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform_convert_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def tuples_and_const(x: R.Tensor, y: R.Tensor) -> R.Tensor:
t2 = (y, y, x)
c = R.const([1, 2, 3], dtype="int32")
R.output(c)
return c
return R.const([1, 2, 3], dtype="int32")

@R.function
def shapes() -> R.Shape:
Expand Down