[Arith] Gate canonical-simplify LT Case 2 on extra scale == +1 #19669
CanonicalSimplifier::Impl::VisitExpr_(LTNode) Case 2 rewrites a
"scaled-by-d sum plus a single leftover split" comparison
S + xn < 0 <=> S/d + (xn // d) < 0 where d = gcd(scales)
into one where the leftover yn % m gets replaced by
floormod(floordiv(yn, d*L), m/(d*L)). The Case 1 derivation that
justifies dropping the remainder xn % d in [0, d) only works when
xn >= 0. With scale = -1 the equivalence becomes <= rather than <, and
the rewrite silently strengthens the predicate by dropping the boundary
case S/d == xn // d.
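The sign dependence described above can be checked by brute force. The sketch below is only a simplified model, not the simplifier's code: it assumes the leftover split is a non-negative integer `y` entering with scale `+1` or `-1`, and that the scaled sum is `d * s`. The names `check_case`, `s`, `y`, and `sign` are hypothetical and introduced for illustration.

```python
from itertools import product


def check_case(sign, d_values=range(1, 6), s_values=range(-6, 7), y_values=range(0, 30)):
    """Collect (d, s, y) where dividing the comparison through by d changes its truth value.

    original : d*s + sign*y < 0         (scaled sum plus leftover, before the rewrite)
    rewritten: s + sign*(y // d) < 0    (after dropping the remainder y % d)
    """
    mismatches = []
    for d, s, y in product(d_values, s_values, y_values):
        original = d * s + sign * y < 0
        rewritten = s + sign * (y // d) < 0
        if original != rewritten:
            mismatches.append((d, s, y))
    return mismatches


# scale = +1: the remainder in [0, d) can be dropped safely.
positive = check_case(+1)

# scale = -1: the rewrite loses exactly the boundary s == y // d with a non-zero remainder.
negative = check_case(-1)

print("scale=+1 mismatches:", len(positive))
print("scale=-1 mismatches:", len(negative))
print("all on the boundary:", all(s == y // d and y % d > 0 for d, s, y in negative))
```

Under these assumptions, every mismatch for `sign = -1` is a case where the original comparison is true and the rewritten one is false, which is the "strengthened predicate" the commit message describes.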
This surfaced as a miscompile in kernels that mask a per-lane write by
`row > col`, where `row = (lane_id // 4) + 16 * warp_id` and
`col = 2 * (lane_id % 4)` are independent projections of the same lane
id. After CSE+inlining the comparison hit canonical_simplify with the
divided projection on the LHS (scale = -1), and Case 2 folded
`2*(tx%4) < 16*warp + (tx%32)//4` into a plain `0 < warp_id`, zeroing
every thread that should have written `val` in warp 0. The same path
also folded other configurations (e.g. `0 < (tx%32) - 8*warp`) all the
way to False.
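A quick way to see the size of the miscompile is to enumerate the predicate directly. The sketch below assumes one 32-lane warp per value of `warp`, treats `tx` and `warp` as independent integers, and picks 4 warps arbitrarily; none of these values come from the PR. The function names are hypothetical.

```python
def original_pred(tx, warp):
    # col < row, with col = 2 * (lane % 4) and row = lane // 4 + 16 * warp
    return 2 * (tx % 4) < 16 * warp + (tx % 32) // 4


def folded_pred(tx, warp):
    # What the buggy Case 2 rewrite produced for the predicate above.
    return 0 < warp


NUM_WARPS = 4  # assumed for illustration

# Lanes that should write `val` but are masked off after the bad fold.
lost = [
    (tx, warp)
    for warp in range(NUM_WARPS)
    for tx in range(32)
    if original_pred(tx, warp) and not folded_pred(tx, warp)
]

# Second shape from the description: folded to False although it is satisfiable.
second_true = [
    (tx, warp)
    for warp in range(NUM_WARPS)
    for tx in range(32)
    if 0 < (tx % 32) - 8 * warp
]

print("lanes wrongly masked:", len(lost))
print("affected warps:", sorted({warp for _, warp in lost}))
print("second predicate satisfiable:", bool(second_true))
```

In this model the two predicates disagree only in warp 0, which matches the reported symptom of warp 0 being zeroed.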
Gate Case 2 with `extra->args[0]->scale == 1`. The original target
shape (`(yn % m)` with positive scale and lower_factor=1, as well as
the scale=+1 + lower_factor>1 generalization) is unchanged; both are
covered by the existing `test_simplify_le` cases and by the new
`test_simplify_le_negative_scale_extra` regression test, which also
pins the buggy scale=-1 shape to its unsimplified form and re-asserts
that the truly-always-true `r=2` variant still folds to True.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Code Review
This pull request restricts Case 2 of the LT-with-divisible-coeffs rewrite in the canonical simplifier to only fire when the leftover split term has a positive scale (scale == 1). This prevents incorrect simplifications when the scale is negative, which previously led to silently dropping boundary conditions. A regression test has been added to verify this behavior. The reviewer suggested adding a defensive check against potential signed integer overflow during the multiplication of gcd and lower_factor to avoid undefined behavior.
cc: @tqchen
Summary
`CanonicalSimplifier::Impl::VisitExpr_(LTNode)` Case 2 rewrites a comparison of the form `S + xn < 0` into `S/d + (xn // d) < 0`, where `d = gcd(scales)`. The Case 1 derivation only works when `xn ≥ 0`. With `scale = -1` the equivalence becomes `≤` rather than `<`, and the rewrite silently strengthens the predicate by dropping the boundary `S/d == xn // d`.

After CSE/inlining, a comparison such as `2*(tx%4) < 16*warp + (tx%32)//4` (where `row` and `col` are independent projections of the same lane id) reaches canonical_simplify with the divided projection on the LHS (`scale = -1`), and Case 2 folds it to a plain `0 < warp_id`, zeroing every thread that should have written `val` in warp 0. The same path also folds other configurations (e.g. `0 < (tx%32) - 8*warp`) all the way to `False`.

The fix gates Case 2 with `extra->args[0]->scale == 1`. The original target shape (`yn % m` with positive scale and `lower_factor=1`, plus the `scale = +1 / lower_factor > 1` generalization) is unchanged; truly-always-true comparisons still fold to `True`.

Test plan
- `test_simplify_le_negative_scale_extra` in `tests/python/arith/test_arith_canonical_simplify.py`: asserts on the simplified `PrimExpr`, no GPU required; pre-fix fails, post-fix passes. It also pins the buggy `scale = -1` shapes to their unsimplified form, confirms the `scale = +1` Case 2 path still optimizes, and re-asserts that the truly-always-true variant still folds to `True`.
- `test_simplify_le` (the original Case 2 target with `scale = +1`) still passes.
- `tests/python/arith/test_arith_canonical_simplify.py`: 16 passed.
- `tests/python/arith/`: 932 passed (1 pre-existing flaky random-seed failure in `test_arith_solve_linear_equations.py`, unrelated to this change; passes on rerun).