Skip to content

Commit

Permalink
[Relax][Transform] Canonicalize let var = R.const bindings (#16601)
Browse files Browse the repository at this point in the history
* [Relax][Transform] Canonicalize `let var = R.const` bindings

Prior to this commit, known tuples could be unwrapped into variables,
but any constants would remain.  This commit updates
`CanonicalizeBindings` to unwrap tuples containing constants.

* Fix broken tests, removing test_unable_to_fold

The `test_unable_to_fold` was ported from the
`FoldDataflowBlockOutput` tests, and has been updated enough that it
no longer a purpose.
  • Loading branch information
Lunderberg committed Feb 19, 2024
1 parent f0c28a0 commit 36ebcd0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 32 deletions.
13 changes: 13 additions & 0 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct CanonicalizationPlan {
Map<Id, Var> replace_usage;
Map<Id, Var> replace_binding;
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
Map<Id, Constant> inline_constant;
};

/*! \brief Utility class to identify usage location
Expand Down Expand Up @@ -69,6 +70,10 @@ class CanonicalizePlanner : public ExprVisitor {
}
}

for (const auto& [var, constant] : visitor.known_bound_to_constant_) {
plan.inline_constant.Set(var->vid, constant);
}

for (const auto& binding_iter : visitor.trivial_bindings_) {
Var bound_var = binding_iter.first;
Var bound_to = binding_iter.second;
Expand Down Expand Up @@ -180,6 +185,10 @@ class CanonicalizePlanner : public ExprVisitor {
trivial_bindings_.Set(binding->var, parent.value());
}

if (auto constant = value.as<Constant>()) {
known_bound_to_constant_.Set(binding->var, constant.value());
}

known_bindings_.Set(binding->var, value);
def_blocks_.Set(binding->var, current_block_.value());

Expand Down Expand Up @@ -213,6 +222,7 @@ class CanonicalizePlanner : public ExprVisitor {

Map<Var, Var> trivial_bindings_;
Map<Var, Expr> known_bindings_;
Map<Var, Constant> known_bound_to_constant_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
// Set of vars either used outside a dataflow block altogether or outside their
// home dataflow block (the one where they were defined)
Expand Down Expand Up @@ -251,6 +261,9 @@ class BindingCanonicalizer : public ExprMutator {
while (auto opt = plan_.replace_usage.Get(new_var->vid)) {
new_var = opt.value();
}
if (auto opt = plan_.inline_constant.Get(new_var->vid)) {
return VisitExpr(opt.value());
}

return ExprMutator::VisitExpr_(new_var.get());
}
Expand Down
67 changes: 36 additions & 31 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,33 +289,6 @@ def main() -> R.Tensor((), "int32"):
verify(Input, Expected)


def test_unable_to_fold():
@I.ir_module
class MultipleUse:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
# multiple uses -> cannot coalesce
m = R.add(n, n)
R.output(n)
return n

@I.ir_module
class ComplexExpr:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
y = R.const(1)
# y does not appear by itself -> cannot coalesce
n = R.add(y, y)
R.output(n)
return n

verify(MultipleUse, MultipleUse)
verify(ComplexExpr, ComplexExpr)


def test_multiple_outputs():
@I.ir_module
class Input:
Expand Down Expand Up @@ -380,18 +353,26 @@ def main():
verify(Input, Expected)


def test_multiply_used_in_outputs():
# cannot fold output in this case
def test_fold_const_to_output():
@I.ir_module
class UsedInMultipleOutputs:
class Before:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return n

verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
@I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return R.const(1)

verify(Before, Expected)


def test_canonicalize_var_to_dataflow_var_if_legal():
Expand Down Expand Up @@ -972,5 +953,29 @@ def transform_params(
assert_structural_equal(Expected, after)


def test_unwrap_tuple_of_constant():
@I.ir_module
class TestChainAssignments:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(x, y)
return z

@I.ir_module
class Expected:
@R.function
def main():
tup = (R.const(0, "int64"), R.const(1, "int64"))
x = tup[0]
y = tup[1]
z = R.add(R.const(0, "int64"), R.const(1, "int64"))
return z

verify(TestChainAssignments, Expected)


if __name__ == "__main__":
tvm.testing.main()
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform_convert_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def tuples_and_const(x: R.Tensor, y: R.Tensor) -> R.Tensor:
t2 = (y, y, x)
c = R.const([1, 2, 3], dtype="int32")
R.output(c)
return c
return R.const([1, 2, 3], dtype="int32")

@R.function
def shapes() -> R.Shape:
Expand Down

0 comments on commit 36ebcd0

Please sign in to comment.