Skip to content

Commit

Permalink
[TIR] Additional Stmt/Expr simplication rules
Browse files Browse the repository at this point in the history
- Enabled simplification of `A[i] = A[i] + 0` into no-op.  This was a
  bug introduced in apache#9727, which
  applied this rewrite only to `A[i] = A[i]`, and not to statements
  which simplify to `A[i] = A[i]`.  Regression test added to prevent
  reoccurrence of this bug.

- Enabled simplification of `x - x` to zero for floating point types.
  Previously, this simplification was applied only for data types that
  could be used as buffer indices.
  • Loading branch information
Lunderberg committed May 25, 2022
1 parent 59d4c02 commit e6f0238
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 17 deletions.
14 changes: 8 additions & 6 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
}

// cancelation rules
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);

if (IsIndexType(op->dtype)) {
// Index rules
// cancelation rules
TVM_TRY_REWRITE((x + y) - y, x);
TVM_TRY_REWRITE((x + y) - x, y);
TVM_TRY_REWRITE(x - (y + x), 0 - y);
TVM_TRY_REWRITE(x - (x + y), 0 - y);

TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x));
TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0));
Expand All @@ -275,7 +278,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));

// mul co-efficient folding
TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
TVM_TRY_REWRITE(x * y - x, x * (y - 1));
TVM_TRY_REWRITE(y * x - x, x * (y - 1));
TVM_TRY_REWRITE(x - y * x, x * (1 - y));
Expand Down
12 changes: 6 additions & 6 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = op->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(op->buffer->data) &&
ArrayDeepEqual(load->indices, op->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, op->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) {
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
ArrayDeepEqual(load->indices, store->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
return Evaluate(0);
}
}
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,5 +972,13 @@ def test_div_zero_simplify():
assert "division by zero" in str(cm.execption)


def test_sub_bufferload():
ck = RewriteChecker()
buf = tvm.tir.decl_buffer([1], dtype="float32")
load = tvm.tir.BufferLoad(buf, [0])
expr = load - load
ck.verify(expr, 0.0)


if __name__ == "__main__":
pytest.main([__file__])
43 changes: 38 additions & 5 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T


def test_stmt_simplify():
Expand Down Expand Up @@ -133,9 +134,41 @@ def sls(n, d):
assert "if" not in str(stmt)


def test_load_store_noop():
"""Store of a value that was just read from the same location is a no-op."""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0]

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
tvm.ir.assert_structural_equal(after, expected)


def test_load_store_noop_after_simplify():
"""As test_load_store_noop, but requiring simplification to identify.
Previously, a bug caused the self-assignment of a buffer to
checked based on the pre-simplification assignment, not the
post-simplification. This test is to identify any similar
regression.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0] + (5.0 - 5.0)

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
test_if_likely()
test_basic_likely_elimination()
test_complex_likely_elimination()
sys.exit(pytest.main(sys.argv))

0 comments on commit e6f0238

Please sign in to comment.