diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 439e2643380a..45d060567c68 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -99,6 +99,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const AllocateConstNode* op) override; LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body); @@ -386,6 +387,16 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) { return std::move(store); } +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const DeclBufferNode* op) { + auto decl = Downcast(StmtExprMutator::VisitStmt_(op)); + + Buffer remapped = GetRemappedBuffer(decl->buffer); + if (!op->buffer.same_as(remapped)) { + decl.CopyOnWrite()->buffer = remapped; + } + return std::move(decl); +} + PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index d0403fcae938..03929c5436be 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -19,7 +19,7 @@ import pytest import tvm from tvm import PoolInfoProperties, WorkspacePoolInfo -from tvm.script import tir as T +from tvm.script import tir as T, ir as I from tvm.target import Target from tvm.tir import stmt_functor from tvm.tir.usmp import utils as usmp_utils @@ -67,6 +67,38 @@ def _assign_targets_to_primfuncs_irmodule(mod, target): return ret +def _plan_and_convert(tir_mod, pools=None): + target = Target("c") + + if pools is None: + pools = [ + WorkspacePoolInfo( + "global_workspace", + [target], + ) + ] + + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, pools) + main_func = tir_mod["__tvm_main__"] + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size( + buffer_info_arr, buffer_analysis.memory_pressure + ) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations, emit_tvmscript_printable=True + )(tir_mod) + + return tir_mod_with_offsets + + # fmt: off @tvm.script.ir_module class LinearStructure: @@ -210,42 +242,24 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def test_mobilenet_subgraph(): - target = Target("c") - fast_memory_pool = WorkspacePoolInfo( - "fast_memory", - [target], - PoolInfoProperties(size_hint_bytes=200704), - ) - slow_memory_pool = WorkspacePoolInfo( - "slow_memory", - [target], - ) - tir_mod = LinearStructure - tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = assign_poolinfos_to_allocates_in_irmodule( - tir_mod, [fast_memory_pool, slow_memory_pool] - ) - main_func = tir_mod["__tvm_main__"] - buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) - buffer_info_map = buffer_analysis.buffer_info_stmts - - fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") - buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") - buffer_pool_allocations = fusmp_algo_greedy_by_size( - buffer_info_arr, buffer_analysis.memory_pressure - ) - fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") - pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) - tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( - pool_allocations, emit_tvmscript_printable=True - )(tir_mod) + before = LinearStructure - tir_mod_with_offsets_ref = LinearStructurePlanned + expected = LinearStructurePlanned - for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): - actual_func = tir_mod_with_offsets[gv.name_hint] - tvm.ir.assert_structural_equal(actual_func, ref_func) + target = Target("c") + pools = [ + WorkspacePoolInfo( + "fast_memory", + [target], + PoolInfoProperties(size_hint_bytes=200704), + ), + WorkspacePoolInfo( + "slow_memory", + [target], + ), + ] + after = _plan_and_convert(before, pools=pools) + tvm.ir.assert_structural_equal(after, expected) # fmt: off @@ -500,35 +514,10 @@ def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), out def test_resnet_subgraph(): - target = Target("c") - global_workspace_pool = WorkspacePoolInfo( - "global_workspace", - [target], - ) - tir_mod = ResnetStructure - tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) - main_func = tir_mod["__tvm_main__"] - buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) - buffer_info_map = buffer_analysis.buffer_info_stmts - - fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") - buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") - buffer_pool_allocations = fusmp_algo_greedy_by_size( - buffer_info_arr, buffer_analysis.memory_pressure - ) - fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") - pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) - tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( - pool_allocations, emit_tvmscript_printable=True - )(tir_mod) - - tir_mod_with_offsets_ref = ResnetStructurePlanned - - for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): - actual_func = tir_mod_with_offsets[gv.name_hint] - tvm.ir.assert_structural_equal(actual_func, ref_func) + before = ResnetStructure + expected = ResnetStructurePlanned + after = _plan_and_convert(before) + tvm.ir.assert_structural_equal(after, expected) @tvm.script.ir_module @@ -591,36 +580,116 @@ def __tvm_main__( def test_tensor_intrin(): - target = Target("c") - global_workspace_pool = WorkspacePoolInfo( - "global_workspace", - [target], - ) - - tir_mod = TensorIntrinStructure - tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) - main_func = tir_mod["__tvm_main__"] - buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) - buffer_info_map = buffer_analysis.buffer_info_stmts - - fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") - buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") - buffer_pool_allocations = fusmp_algo_greedy_by_size( - buffer_info_arr, buffer_analysis.memory_pressure - ) - fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") - pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) - tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( - pool_allocations, emit_tvmscript_printable=True - )(tir_mod) - + before = TensorIntrinStructure + after = _plan_and_convert(before) expected = TensorIntrinStructurePlanned - - for gv, ref_func in expected.functions.items(): - actual_func = tir_mod_with_offsets[gv.name_hint] - tvm.ir.assert_structural_equal(actual_func, ref_func) + tvm.ir.assert_structural_equal(after, expected) + + +class TestMergeAllocations(tvm.testing.CompareBeforeAfter): + def transform(self): + return _plan_and_convert + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")): + B = T.allocate([256], "int8") + T.call_extern("subroutine", A.data, B, dtype="int32") + C = T.allocate([256], "int8") + T.call_extern("subroutine", B, C, dtype="int32") + T.call_extern("subroutine", C, D.data, dtype="int32") + + @T.prim_func + def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")): + for i in range(256): + B[i] = A[i] + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def __tvm_main__( + A: T.Buffer(256, "int8"), + D: T.Buffer(256, "int8"), + workspace_var: T.handle("uint8"), + ): + workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16) + B: T.handle("int8") = T.address_of(workspace[256]) + T.call_extern("subroutine", A.data, B, workspace.data, dtype="int32") + C: T.handle("int8") = T.address_of(workspace[0]) + T.call_extern("subroutine", B, C, workspace.data, dtype="int32") + T.call_extern("subroutine", C, D.data, workspace.data, dtype="int32") + + @T.prim_func + def subroutine( + A: T.Buffer(256, "int8"), + B: T.Buffer(256, "int8"), + workspace_var: T.handle("uint8"), + ): + workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16) + for i in range(256): + B[i] = A[i] + + return mod + + +class TestMergeAllocationsWithDeclBuffer(tvm.testing.CompareBeforeAfter): + """Like TestMergeAllocations, but using T.decl_buffer""" + + def transform(self): + return _plan_and_convert + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")): + B = T.decl_buffer([256], "int8") + T.call_extern("subroutine", A.data, B.data, dtype="int32") + C = T.decl_buffer([256], "int8") + T.call_extern("subroutine", B.data, C.data, dtype="int32") + T.call_extern("subroutine", C.data, D.data, dtype="int32") + + @T.prim_func + def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")): + for i in range(256): + B[i] = A[i] + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def __tvm_main__( + A: T.Buffer(256, "int8"), + D: T.Buffer(256, "int8"), + workspace_var: T.handle("uint8"), + ): + workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16) + B_data: T.handle("int8") = T.address_of(workspace[256]) + B = T.decl_buffer(256, "int8", data=B_data) + T.call_extern("subroutine", A.data, B.data, workspace.data, dtype="int32") + C_data: T.handle("int8") = T.address_of(workspace[0]) + C = T.decl_buffer(256, "int8", data=C_data) + T.call_extern("subroutine", B.data, C.data, workspace.data, dtype="int32") + T.call_extern("subroutine", C.data, D.data, workspace.data, dtype="int32") + + @T.prim_func + def subroutine( + A: T.Buffer(256, "int8"), + B: T.Buffer(256, "int8"), + workspace_var: T.handle("uint8"), + ): + workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16) + for i in range(256): + B[i] = A[i] + + return mod if __name__ == "__main__":