Skip to content

Commit

Permalink
Enhance IterMapSimplify to support uncommon predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 5, 2024
1 parent b72cd0f commit 9671656
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
9 changes: 9 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,15 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
/*simplify_trivial_iterators=*/simplify_trivial_iterators);
Array<IterSumExpr> rewrite = res->indices;

if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) {
// The input predicate may cause detect iter map to fail
// but we can still detect the iter map without the input predicate
// in which case the resulting iter map is valid and can be used for simplification.
rewrite = DetectIterMap(indices, input_iters, const_true(), check_level, ana,
/*simplify_trivial_iterators=*/simplify_trivial_iterators)
->indices;
}

if (rewrite.empty()) {
return indices;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,7 @@ def expected(A: T.Buffer((14, 32), "int32")):
with T.init():
B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32")
B[vi, vj] = T.if_then_else(
vi == 3 and 2 <= vj,
0,
B[vj // 4 + vi, vj % 4] + A[vi * 4 + vj, vk],
dtype="int32",
vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32"
)


Expand All @@ -767,7 +764,7 @@ def expected(A: T.Buffer((14, 32), "int32")):
for k in T.serial(32):
with T.block("block"):
B[i, j] = T.if_then_else(
i == 3 and 2 <= j, 0, B[j // 4 + i, j % 4] + A[i * 4 + j, k], dtype="int32"
i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32"
)


Expand Down

0 comments on commit 9671656

Please sign in to comment.