Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][USMP] Preserve DeclBuffer in PoolAllocationToOffsetConverter #15044

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;

Stmt VisitStmt_(const AllocateConstNode* op) override;
LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body);
Expand Down Expand Up @@ -386,6 +387,16 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) {
return std::move(store);
}

/*!
 * \brief Visit a DeclBuffer and point it at the remapped buffer when one exists.
 *
 * The node is first mutated by the base StmtExprMutator. The buffer of the
 * mutated node is then passed through GetRemappedBuffer; if the result is not
 * the same object as the buffer of the original node, the declaration is
 * updated (via CopyOnWrite) to reference the remapped buffer.
 *
 * \param op The DeclBuffer node being visited.
 * \return The (possibly updated) DeclBuffer statement.
 */
Stmt PoolAllocationToOffsetConverter::VisitStmt_(const DeclBufferNode* op) {
  DeclBuffer node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

  Buffer new_buffer = GetRemappedBuffer(node->buffer);
  if (op->buffer.same_as(new_buffer)) {
    // Nothing was remapped; hand back the base mutator's result untouched.
    return std::move(node);
  }

  node.CopyOnWrite()->buffer = new_buffer;
  return std::move(node);
}

PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
import tvm
from tvm import PoolInfoProperties, WorkspacePoolInfo
from tvm.script import tir as T
from tvm.script import tir as T, ir as I
from tvm.target import Target
from tvm.tir import stmt_functor
from tvm.tir.usmp import utils as usmp_utils
Expand Down Expand Up @@ -67,6 +67,38 @@ def _assign_targets_to_primfuncs_irmodule(mod, target):
return ret


def _plan_and_convert(tir_mod, pools=None):
    """Plan memory with USMP and convert pool allocations to offsets.

    Parameters
    ----------
    tir_mod
        The IRModule to transform. Its "__tvm_main__" function is used as the
        entry point for buffer-info extraction.
    pools
        Pool infos to assign to the allocates. When None, a single
        WorkspacePoolInfo named "global_workspace" targeting Target("c") is used.

    Returns
    -------
    The module produced by convert_pool_allocations_to_offsets (called with
    emit_tvmscript_printable=True) using the greedy_by_size plan.
    """
    target = Target("c")

    if pools is None:
        pools = [WorkspacePoolInfo("global_workspace", [target])]

    # Annotate the module with targets and candidate pools before analysis.
    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, pools)

    analysis = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["__tvm_main__"], tir_mod)
    info_stmts = analysis.buffer_info_stmts

    # Run the greedy-by-size planner over the extracted buffer infos.
    info_array = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")(info_stmts)
    planned = tvm.get_global_func("tir.usmp.algo.greedy_by_size")(
        info_array, analysis.memory_pressure
    )

    # Map the plan back onto statements, then apply the conversion pass.
    stmt_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")(
        info_stmts, planned
    )
    convert = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
        stmt_allocations, emit_tvmscript_printable=True
    )
    return convert(tir_mod)


# fmt: off
@tvm.script.ir_module
class LinearStructure:
Expand Down Expand Up @@ -210,42 +242,24 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde


def test_mobilenet_subgraph():
target = Target("c")
fast_memory_pool = WorkspacePoolInfo(
"fast_memory",
[target],
PoolInfoProperties(size_hint_bytes=200704),
)
slow_memory_pool = WorkspacePoolInfo(
"slow_memory",
[target],
)
tir_mod = LinearStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)
before = LinearStructure

tir_mod_with_offsets_ref = LinearStructurePlanned
expected = LinearStructurePlanned

for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
target = Target("c")
pools = [
WorkspacePoolInfo(
"fast_memory",
[target],
PoolInfoProperties(size_hint_bytes=200704),
),
WorkspacePoolInfo(
"slow_memory",
[target],
),
]
after = _plan_and_convert(before, pools=pools)
tvm.ir.assert_structural_equal(after, expected)


# fmt: off
Expand Down Expand Up @@ -500,35 +514,10 @@ def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), out


def test_resnet_subgraph():
target = Target("c")
global_workspace_pool = WorkspacePoolInfo(
"global_workspace",
[target],
)
tir_mod = ResnetStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)

tir_mod_with_offsets_ref = ResnetStructurePlanned

for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
before = ResnetStructure
expected = ResnetStructurePlanned
after = _plan_and_convert(before)
tvm.ir.assert_structural_equal(after, expected)


@tvm.script.ir_module
Expand Down Expand Up @@ -591,36 +580,116 @@ def __tvm_main__(


def test_tensor_intrin():
target = Target("c")
global_workspace_pool = WorkspacePoolInfo(
"global_workspace",
[target],
)

tir_mod = TensorIntrinStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)

before = TensorIntrinStructure
after = _plan_and_convert(before)
expected = TensorIntrinStructurePlanned

for gv, ref_func in expected.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
tvm.ir.assert_structural_equal(after, expected)


# Before/after test for _plan_and_convert on a module whose main function
# uses two T.allocate buffers of 256 int8 elements each.
# NOTE(review): presumably the CompareBeforeAfter base applies `transform` to
# `before` and compares against `expected` -- confirm against tvm.testing.
class TestMergeAllocations(tvm.testing.CompareBeforeAfter):
# The transformation under test is the module-level planning helper.
def transform(self):
return _plan_and_convert

# Input module: __tvm_main__ allocates B and C and chains three extern calls
# to "subroutine" (A -> B, B -> C, C -> D).
def before(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")):
B = T.allocate([256], "int8")
T.call_extern("subroutine", A.data, B, dtype="int32")
C = T.allocate([256], "int8")
T.call_extern("subroutine", B, C, dtype="int32")
T.call_extern("subroutine", C, D.data, dtype="int32")

@T.prim_func
def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
for i in range(256):
B[i] = A[i]

return mod

# Expected module: both functions gain a `workspace_var` parameter matched to
# a 512-element uint8 `workspace` buffer; B and C become handles bound to
# addresses inside it (offsets 256 and 0), and every extern call receives
# `workspace.data` as an additional argument.
def expected(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(
A: T.Buffer(256, "int8"),
D: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
B: T.handle("int8") = T.address_of(workspace[256])
T.call_extern("subroutine", A.data, B, workspace.data, dtype="int32")
C: T.handle("int8") = T.address_of(workspace[0])
T.call_extern("subroutine", B, C, workspace.data, dtype="int32")
T.call_extern("subroutine", C, D.data, workspace.data, dtype="int32")

@T.prim_func
def subroutine(
A: T.Buffer(256, "int8"),
B: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
for i in range(256):
B[i] = A[i]

return mod


class TestMergeAllocationsWithDeclBuffer(tvm.testing.CompareBeforeAfter):
"""Like TestMergeAllocations, but using T.decl_buffer"""

# The transformation under test is the module-level planning helper.
def transform(self):
return _plan_and_convert

# Input module: B and C are introduced with T.decl_buffer, and the extern
# calls pass their `.data` handles.
def before(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")):
B = T.decl_buffer([256], "int8")
T.call_extern("subroutine", A.data, B.data, dtype="int32")
C = T.decl_buffer([256], "int8")
T.call_extern("subroutine", B.data, C.data, dtype="int32")
T.call_extern("subroutine", C.data, D.data, dtype="int32")

@T.prim_func
def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
for i in range(256):
B[i] = A[i]

return mod

# Expected module: the T.decl_buffer declarations for B and C are still
# present, now backed by `B_data` / `C_data` handles bound to addresses inside
# the 512-element uint8 `workspace` buffer (offsets 256 and 0). Every extern
# call receives `workspace.data` as an additional argument.
def expected(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(
A: T.Buffer(256, "int8"),
D: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
B_data: T.handle("int8") = T.address_of(workspace[256])
B = T.decl_buffer(256, "int8", data=B_data)
T.call_extern("subroutine", A.data, B.data, workspace.data, dtype="int32")
C_data: T.handle("int8") = T.address_of(workspace[0])
C = T.decl_buffer(256, "int8", data=C_data)
T.call_extern("subroutine", B.data, C.data, workspace.data, dtype="int32")
T.call_extern("subroutine", C.data, D.data, workspace.data, dtype="int32")

@T.prim_func
def subroutine(
A: T.Buffer(256, "int8"),
B: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
for i in range(256):
B[i] = A[i]

return mod


if __name__ == "__main__":
Expand Down
Loading