Skip to content

Commit

Permalink
Simplify nested if_then_else when constant is appearing in then_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
rutkoor committed Dec 13, 2023
1 parent a9fcac1 commit 3e039f2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
43 changes: 30 additions & 13 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,36 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
if (op->op.same_as(op_if_then_else)) {
PrimExpr cond = this->VisitExpr(op->args[0]);
PrimExpr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
if (is_const_number(op->args[2]) || !is_const_number(op->args[1])) {
{
With<ConstraintContext> constraint(analyzer_, cond);
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
} else {
{
With<ConstraintContext> constraint(analyzer_, cond);
WithRecordIterPredicate(cond, [&] { false_value = this->VisitExpr(op->args[2]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
true_value = this->VisitExpr(op->args[1]);
}
if (is_zero(cond)) {
return true_value;
}
if (is_one(cond)) {
return false_value;
}
}
if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
Expand Down
22 changes: 22 additions & 0 deletions tests/python/tir-transform/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,5 +1757,27 @@ def expected(a: T.handle):
A[T.int64(1)] = T.float32(0)


class TestNestedIfElimination(BaseBeforeAfter):
def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
with T.block("P"):
i1, j1 = T.axis.remap("SS", [i0, j0])
T.reads(a[i1, j1])
T.writes(b[i1, j1])
b[i1, j1] = T.if_then_else(
i1 == 1 and 6 <= j1,
0,
T.max(0, T.if_then_else(i1 == 1 and 6 <= j1, 0, a[i1, j1])),
)

def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
with T.block("P"):
i1, j1 = T.axis.remap("SS", [i0, j0])
T.reads(a[i1, j1])
T.writes(b[i1, j1])
b[i1, j1] = T.if_then_else(i1 == 1 and 6 <= j1, 0, T.max(0, a[i1, j1]))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 3e039f2

Please sign in to comment.