From 7c45f4cd116b002b3b5cb14f8b7b421336e325fa Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 12 Apr 2023 20:25:40 +0900
Subject: [PATCH 1/2] clean

---
 src/relax/transform/fuse_tir.cc | 39 +++++++++++++--------------------------
 1 file changed, 13 insertions(+), 26 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 183d395b6078..57257fe5cc67 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -117,15 +117,13 @@ class SymbolicMatcher : ExprFunctor
 
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class FuseTIRBufferSubstitutor : private StmtExprMutator {
  public:
-  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
-                                  const Map<Var, Var>& var_map) {
+  explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
+                                    const Map<Var, Var>& var_map) {
     buffer_remap_ = buffer_map;
     var_remap_ = var_map;
-    for (const auto& kv : buffer_map) {
-      const Buffer& src = kv.first;
-      const Buffer& tgt = kv.second;
+    for (const auto& [src, tgt] : buffer_map) {
       var_remap_.Set(src->data, tgt->data);
     }
   }
@@ -351,10 +349,8 @@ class FusedTIRConstructor : public ExprVisitor {
         // It's a symbolic shape var, no need to alloc Buffers.
         continue;
       }
-      auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param),  //
-                                        relax_param->name_hint());
-      const Array<tir::Var>& params = ret.first;
-      const Array<tir::Buffer>& buffers = ret.second;
+      auto [params, buffers] = CreateParamsAndBuffers(GetStructInfo(relax_param),  //
+                                                      relax_param->name_hint());
       ICHECK_EQ(params.size(), buffers.size());
       for (size_t i = 0; i < params.size(); ++i) {
         func_info_.buffer_map.Set(params[i], buffers[i]);
       }
@@ -384,10 +380,8 @@ class FusedTIRConstructor : public ExprVisitor {
     // Step 4. Append symbolic vars
     const relax::Var& last_relax_param = func->params.back();
     if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
-      auto ret =
-          CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
-      const Array<tir::Var>& params = ret.first;
-      const Array<tir::Buffer>& buffers = ret.second;
+      auto [params, buffers] =
+          CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
       ICHECK(buffers.empty());
       for (size_t i = 0; i < params.size(); ++i) {
         func_info_.params.push_back(params[i]);
@@ -682,9 +676,7 @@ class FusedTIRConstructor : public ExprVisitor {
            "list.";
     if (index == -1) index = 0;
     for (size_t i = 0; i < tuple->fields.size(); ++i) {
-      auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-      const Array<tir::Var>& ret_params = ret.first;
-      const Array<tir::Buffer>& ret_buffers = ret.second;
+      auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
       ICHECK_EQ(ret_params.size(), ret_buffers.size());
       // Adding tuple field results to the end of params and buffers.
       params.insert(params.end(), ret_params.begin(), ret_params.end());
@@ -714,19 +706,18 @@ class FusedTIRConstructor : public ExprVisitor {
   tir::PrimFunc ConstructFunc() {
     Map<String, ObjectRef> attr_map;
     attr_map.Set("tir.noalias", tir::const_true());
-    tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
-                                          func_info_.symbolic_var_remap);
+    tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap);
     ICHECK(func_info_.global_name != "fused");
     // Remove output buffers from func_info_.alloc_buffers
     Array<tir::Buffer> alloc_buffers;
     for (const tir::Buffer& buf : func_info_.alloc_buffers) {
       if (func_info_.output_buffers.count(buf.get()) == 0) {
-        alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
+        alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf));
       }
     }
     tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
-    body = substitor.Substitute(body);
+    body = subst.Substitute(body);
     body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
     body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
     tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
@@ -804,9 +795,7 @@ class TIRFuseMutator : public ExprMutator {
     // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
     TIRFuseMutator mutator(mod);
     // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
-    for (const auto& kv : mod->functions) {
-      const GlobalVar& gv = kv.first;
-      const BaseFunc& func = kv.second;
+    for (const auto& [gv, func] : mod->functions) {
       // Only fuse primitive relax functions
       if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
         tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
@@ -816,9 +805,7 @@ class TIRFuseMutator : public ExprMutator {
 
     // Step 2. Update all non-primitive relax functions and add it, with the dependent function,
     // into the new IRModule
-    for (const auto& kv : mod->functions) {
-      const GlobalVar& gv = kv.first;
-      const BaseFunc& func = kv.second;
+    for (const auto& [gv, func] : mod->functions) {
       if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
         relax::Function update_func = Downcast<relax::Function>(mutator.VisitExpr(func));
         mutator.builder_->AddFunction(update_func, gv->name_hint);

From d7606b99aba69c6757e6e0f635dacfb4295f57de Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 12 Apr 2023 20:26:19 +0900
Subject: [PATCH 2/2] add test

---
 src/relax/transform/fuse_tir.cc               |   4 -
 tests/python/relax/test_transform_fuse_tir.py | 120 ++++++++++++++++++
 2 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 57257fe5cc67..432ddca0a751 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -244,8 +244,6 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
   Map<Buffer, Buffer> buffer_remap_;
   /*! \brief Mapping from src tir var to tgt var. */
   Map<Var, Var> var_remap_;
-  /*! \brief The structural equality checker */
-  StructuralEqual structural_equal_;
 
   Array<BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
     // For now we only allow Buffer access the same elements.
@@ -260,8 +258,6 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
       if (it == buffer_region_set.end()) {
         ret.push_back(region);
         buffer_region_set[region->buffer.get()] = region->region;
-      } else {
-        ICHECK(structural_equal_(region->region, it->second));
       }
     }

diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 8b856d3cc598..c7aa7984be88 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -883,5 +883,125 @@ def main(
     _check(Before, Expected)
 
 
+def test_same_buffer_multiple_read():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def concatenate(
+            rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
+            rxplaceholder_1: T.Buffer(
+                (T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"
+            ),
+            T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
+        ):
+            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
+                with T.block("T_concat"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
+                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+                    )
+                    T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
+                        T.int64(1) <= v_ax0,
+                        rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
+                        rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+                    )
+
+        @T.prim_func
+        def transpose2(
+            rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
+            T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"),
+        ):
+            T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
+                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
+                        v_ax0, v_ax3, v_ax1, v_ax2
+                    ]
+
+        @R.function
+        def fused_concatenate_transpose2(
+            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.concatenate,
+                    (inp_0, inp_0),
+                    out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"),
+                )
+                gv = R.call_tir(
+                    cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+        ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+            R.func_attr({"num_input": 3})
+            cls = Module
+            with R.dataflow():
+                lv = cls.fused_concatenate_transpose2(inp_0)
+                R.output(lv)
+            return lv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_concatenate_transpose2(
+            inp_0: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"),
+            T_transpose_handle_intermediate: T.Buffer(
+                (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            T_concat_handle_intermediate = T.alloc_buffer(
+                (T.int64(2), T.int64(4), T.int64(64), T.int64(64))
+            )
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
+                with T.block("T_concat"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3]) + T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else( + T.int64(1) <= v_ax0, + inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], + inp_0[v_ax0, v_ax1, v_ax2, v_ax3], + ) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2]) + T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose_handle_intermediate[ + v_ax0, v_ax1, v_ax2, v_ax3 + ] = T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2] + + @R.function + def main( + inp_0: R.Tensor((1, 4, 64, 64), dtype="float32") + ) -> R.Tensor((2, 64, 64, 4), dtype="float32"): + R.func_attr({"num_input": 3}) + cls = Expected + with R.dataflow(): + lv = R.call_tir( + cls.fused_concatenate_transpose2, + (inp_0,), + out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"), + ) + R.output(lv) + return lv + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main()