diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 8cb01dfe6d074..fd2c27dcd154f 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -50,6 +50,18 @@ class RenewDefMutator : public StmtExprMutator { for (const auto& param : func->params) { params.push_back(generator.ReDefineVar(param)); } + for (const auto& param : func->params) { + if (param->dtype.is_handle()) { + const Buffer& buffer = func->buffer_map.at(param); + for (const PrimExpr& e : buffer->shape) { + if (const auto* v = e.as()) { + if (generator.remap_.count(GetRef(v)) == 0) { + generator.ReDefineVar(GetRef(v)); + } + } + } + } + } // Redefine buffers in order // TODO(Siyuan Feng): checking var is used after define Map buffer_map; diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py index e01f5ecb12ead..5747cf5e4ad5d 100644 --- a/tests/python/unittest/test_tir_renew_defs.py +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. -import pytest import sys +import pytest import tvm import tvm.testing from tvm.script import tir as T @@ -76,6 +76,7 @@ def elementwise(A: T.Buffer((128, 128), "float32")): assert f1.body.block.body.loop_var != f2.body.block.body.loop_var # check remap of j assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var + # check inner block def _get_block(f): return f.body.block.body.body.body.block @@ -169,5 +170,21 @@ def symbolic_func(a: T.handle, b: T.handle, n: T.int32): tvm.ir.assert_structural_equal(f1, f2) +def test_buffer_map(): + @T.prim_func + def main(a: T.handle, b: T.handle): + m = T.int64() + A = T.match_buffer(a, (m * 2,)) + B = T.match_buffer(b, (m, 2)) + for i, j in T.grid(m, 2): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * 2 + vj] + + f1 = main + f2 = tvm.tir.stmt_functor.renew_defs(main) + tvm.ir.assert_structural_equal(f1, f2) + + if __name__ == "__main__": tvm.testing.main()