Skip to content

Commit

Permalink
Simplify nested if_then_else when constant is appearing in then_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
rutkoor committed Dec 14, 2023
1 parent a9fcac1 commit 4998a87
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
false_value = this->VisitExpr(op->args[2]);
PrimExpr not_cond = Not(cond);
With<ConstraintContext> constraint(analyzer_, not_cond);
WithRecordIterPredicate(not_cond, [&] { false_value = this->VisitExpr(op->args[2]); });
}
if (is_zero(cond)) {
return false_value;
Expand Down
22 changes: 22 additions & 0 deletions tests/python/tir-transform/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,5 +1757,27 @@ def expected(a: T.handle):
A[T.int64(1)] = T.float32(0)


class TestNestedIfElimination(BaseBeforeAfter):
def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
with T.block("P"):
i1, j1 = T.axis.remap("SS", [i0, j0])
T.reads(a[i1, j1])
T.writes(b[i1, j1])
b[i1, j1] = T.if_then_else(
i1 == 1 and 6 <= j1,
0,
T.max(0, T.if_then_else(i1 == 1 and 6 <= j1, 0, a[i1, j1])),
)

def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
with T.block("P"):
i1, j1 = T.axis.remap("SS", [i0, j0])
T.reads(a[i1, j1])
T.writes(b[i1, j1])
b[i1, j1] = T.if_then_else(i1 == 1 and 6 <= j1, 0, T.max(0, a[i1, j1]))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 4998a87

Please sign in to comment.