Skip to content

Commit

Permalink
Enhance arith to support predicate simplifications for uncommon predi…
Browse files Browse the repository at this point in the history
…cates
  • Loading branch information
tqchen committed Feb 5, 2024
1 parent b72cd0f commit 39d5375
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
9 changes: 9 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,15 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
/*simplify_trivial_iterators=*/simplify_trivial_iterators);
Array<IterSumExpr> rewrite = res->indices;

if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) {
// The input predicate may cause detect iter map to fail
// but we can still detect the iter map without the input predicate
// in which case the resulting iter map is valid and can be used for simplification.
rewrite = DetectIterMap(indices, input_iters, const_true(), check_level, ana,
/*simplify_trivial_iterators=*/simplify_trivial_iterators)
->indices;
}

if (rewrite.empty()) {
return indices;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,7 @@ def expected(A: T.Buffer((14, 32), "int32")):
with T.init():
B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32")
B[vi, vj] = T.if_then_else(
vi == 3 and 2 <= vj,
0,
B[vj // 4 + vi, vj % 4] + A[vi * 4 + vj, vk],
dtype="int32",
vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32"
)


Expand Down

0 comments on commit 39d5375

Please sign in to comment.