Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
// Case 1. 0 <= xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
} else if (extra->args.size() == 1 &&
} else if (extra->args.size() == 1 && extra->args[0]->scale == 1 &&
extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf &&
extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) {
Comment thread
tqchen marked this conversation as resolved.
// Case 2. xn == yn % m, where m % d == 0
// Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0.
// S + xn < 0 with S divisible by d ⇔ S/d + xn // d < 0, because
// xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument,
// and xn // d = (yn // (d*L)) % (m/(d*L)).
// The scale must be +1: with scale = -1 the equivalence becomes ≤
// rather than <, so the rewrite would strengthen the predicate and
// silently drop the boundary S/d == xn // d (e.g. row > col where
// row and col are independent projections of the same lane id).
divisible.CopyOnWrite()->DivideBy(gcd);
const auto split_expr = extra->args[0];
int64_t lower_factor = gcd * extra->args[0]->lower_factor;
Expand Down
44 changes: 44 additions & 0 deletions tests/python/arith/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,5 +490,49 @@ def test_simplify_le():
ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)


def test_simplify_le_negative_scale_extra():
"""Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not
fire when the leftover split term has a negative scale.

The rewrite ``S + xn < 0 ⇔ S/d + xn // d < 0`` is only sound when
the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence
becomes ``≤`` rather than ``<`` and the rewrite silently strengthens
the predicate. The original bug surfaced as ``row > col`` masks of
``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k``
comparisons (lower-triangle writes were silently dropped on the
boundary warp).
"""
ck = CanonicalChecker()
tx = tvm.tirx.Var("tx", "int32")
warp = tvm.tirx.Var("warp", "int32")
ck.analyzer.bind(tx, tvm.ir.Range(0, 128))
ck.analyzer.bind(warp, tvm.ir.Range(0, 4))

# Same-source joint projection: the comparison genuinely depends on tx
# at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False),
# so the simplifier must keep both sides. Pre-fix this folded to
# ``0 < warp`` and dropped every True case in warp 0.
expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4
ck.verify(expr, expr)

# The simpler ``scale = -1`` with ``lower_factor = 1`` shape. Pre-fix
# this folded to ``False`` (drops all warp >= 1 cases where the rhs
# actually exceeds 8*warp).
expr = warp * 8 < (tx % 32)
ck.verify(expr, expr)

# The corresponding ``scale = +1`` Case 2 path (the rewrite this guards)
# must still optimize — verifies we did not over-restrict.
x1 = tvm.tirx.Var("x1", "int32")
y1 = tvm.tirx.Var("y1", "int32")
ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15)

# The truly-always-true comparison that arises from the same kernel
# (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still
# fold to True so the masked store can be elided.
expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
ck.verify(expr_true, tvm.tirx.const(True, "bool"))


if __name__ == "__main__":
tvm.testing.main()
Loading