From 3431b22956dc67238827e15551e7cafa6810da5c Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 10:52:05 -0700 Subject: [PATCH 1/2] Prevent IRSbustitute to create new buffer when buffer var is unchanged --- src/tir/ir/stmt_functor.cc | 5 +++-- tests/cpp/ir_functor_test.cc | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 06933c2c0dcb..34bbb4b46ba4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -690,9 +690,10 @@ class IRSubstitute : public StmtExprMutator { return it->second; } - if (auto mapped_var = vmap_(buf->data)) { + auto new_buffer_var = vmap_(buf->data); + if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { auto writer = buf.CopyOnWrite(); - writer->data = Downcast(mapped_var); + writer->data = Downcast(new_buffer_var); } buf_remap_[key] = buf; diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index d02c38f3afac..49b1c4cd9236 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -325,3 +325,45 @@ TEST(IRF, StmtMutator) { ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); } } + +TEST(IRF, Substitute) { + using namespace tvm; + using namespace tvm::tir; + DataType dtype = DataType::Float(32); + Var x("x", PointerType(PrimType(dtype), "")); + auto fmaketest = [&]() { + Buffer buffer{/*data=*/x, + /*dtype=*/DataType::Float(32), + /*shape=*/{}, + /*strides=*/{}, + /*elem_offset=*/NullValue(), + /*name=*/"buf", + /*data_alignment=*/1, + /*offset_factor=*/1, + /*buffer_type=*/BufferType::kDefault}; + return BufferLoad(buffer, {}); + }; + + { + // test substitute buffer var + Var y = x.copy_with_suffix("subst"); + BufferLoad buffer_load = fmaketest(); + auto f_subst = [&](const Var& var) -> Optional { + if (var.same_as(x)) { + return y; + } + return NullOpt; + }; + BufferLoad new_buffer_load = Downcast(Substitute(buffer_load, f_subst)); + ICHECK(new_buffer_load->buffer->data.same_as(y)); + } + + { + // test identity substition + PrimExpr expr = fmaketest(); + auto f_subst = [&](const Var& var) -> Optional { return var; }; + PrimExpr new_expr = Substitute(expr, f_subst); + // the expression is not changed + ICHECK(new_expr.same_as(expr)); + } +} From 7d0f88c99aab8993b41e9ecc08b41ceb30d2f075 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 11:43:01 -0700 Subject: [PATCH 2/2] typo --- tests/cpp/ir_functor_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 49b1c4cd9236..33b145d3a41d 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -359,7 +359,7 @@ TEST(IRF, Substitute) { } { - // test identity substition + // test identity substitution PrimExpr expr = fmaketest(); auto f_subst = [&](const Var& var) -> Optional { return var; }; PrimExpr new_expr = Substitute(expr, f_subst);