Skip to content

Commit

Permalink
[TIR][USMP] Preserve DeclBuffer in PoolAllocationToOffsetConverter (#…
Browse files Browse the repository at this point in the history
…15044)

Previously, `PoolAllocationToOffsetConverter` did not remap buffer
objects occurring in `DeclBuffer` nodes.  This commit updates
`PoolAllocationToOffsetConverter` to handle `DeclBuffer` nodes. This
is a subset of changes, being split out from
#14778 into independent portions.
  • Loading branch information
Lunderberg committed Jun 16, 2023
1 parent 3f2aa68 commit fa8a9f7
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 93 deletions.
11 changes: 11 additions & 0 deletions src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;

Stmt VisitStmt_(const AllocateConstNode* op) override;
LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body);
Expand Down Expand Up @@ -386,6 +387,16 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) {
return std::move(store);
}

/*!
 * \brief Visit a DeclBuffer node and make it declare the remapped buffer.
 *
 * The node is first mutated by the base StmtExprMutator. The buffer held by
 * the resulting DeclBuffer is then passed through GetRemappedBuffer.
 *
 * \param op The DeclBuffer node being visited.
 * \return The mutated DeclBuffer. Its buffer is replaced only when the buffer
 *         returned by GetRemappedBuffer is not the same object as the buffer
 *         of the original node.
 */
Stmt PoolAllocationToOffsetConverter::VisitStmt_(const DeclBufferNode* op) {
  DeclBuffer node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

  const Buffer new_buffer = GetRemappedBuffer(node->buffer);
  // Nothing to rewrite when the remapping hands back the original buffer object;
  // skipping the CopyOnWrite() call leaves the node untouched in that case.
  if (op->buffer.same_as(new_buffer)) {
    return std::move(node);
  }

  node.CopyOnWrite()->buffer = new_buffer;
  return std::move(node);
}

PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
import tvm
from tvm import PoolInfoProperties, WorkspacePoolInfo
from tvm.script import tir as T
from tvm.script import tir as T, ir as I
from tvm.target import Target
from tvm.tir import stmt_functor
from tvm.tir.usmp import utils as usmp_utils
Expand Down Expand Up @@ -67,6 +67,38 @@ def _assign_targets_to_primfuncs_irmodule(mod, target):
return ret


def _plan_and_convert(tir_mod, pools=None):
    """Plan pool allocations for `tir_mod` and convert them to pool offsets.

    Parameters
    ----------
    tir_mod
        The module to transform. It must contain a function named
        "__tvm_main__", which is used as the entry point for buffer analysis.

    pools
        Optional list of pool infos to assign to the allocations. When None,
        a single `WorkspacePoolInfo` named "global_workspace" for
        `Target("c")` is used.

    Returns
    -------
    The module produced by applying
    `tvm.tir.usmp.transform.convert_pool_allocations_to_offsets` (with
    `emit_tvmscript_printable=True`) to the annotated module.
    """
    c_target = Target("c")

    if pools is None:
        pools = [WorkspacePoolInfo("global_workspace", [c_target])]

    # Annotate the module with targets and pool candidates before analysis.
    annotated_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, c_target)
    annotated_mod = assign_poolinfos_to_allocates_in_irmodule(annotated_mod, pools)

    entry_func = annotated_mod["__tvm_main__"]
    analysis = tvm.tir.usmp.analysis.extract_buffer_info(entry_func, annotated_mod)
    stmt_to_buffer_info = analysis.buffer_info_stmts

    # The planning steps are reached through the global function registry.
    make_buffer_info_array = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
    buffer_infos = make_buffer_info_array(stmt_to_buffer_info)

    greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    planned_allocations = greedy_by_size(buffer_infos, analysis.memory_pressure)

    assign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
    stmt_pool_allocations = assign_stmt_pool_allocations(stmt_to_buffer_info, planned_allocations)

    convert_pass = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
        stmt_pool_allocations, emit_tvmscript_printable=True
    )
    return convert_pass(annotated_mod)


# fmt: off
@tvm.script.ir_module
class LinearStructure:
Expand Down Expand Up @@ -210,42 +242,24 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde


def test_mobilenet_subgraph():
target = Target("c")
fast_memory_pool = WorkspacePoolInfo(
"fast_memory",
[target],
PoolInfoProperties(size_hint_bytes=200704),
)
slow_memory_pool = WorkspacePoolInfo(
"slow_memory",
[target],
)
tir_mod = LinearStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)
before = LinearStructure

tir_mod_with_offsets_ref = LinearStructurePlanned
expected = LinearStructurePlanned

for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
target = Target("c")
pools = [
WorkspacePoolInfo(
"fast_memory",
[target],
PoolInfoProperties(size_hint_bytes=200704),
),
WorkspacePoolInfo(
"slow_memory",
[target],
),
]
after = _plan_and_convert(before, pools=pools)
tvm.ir.assert_structural_equal(after, expected)


# fmt: off
Expand Down Expand Up @@ -500,35 +514,10 @@ def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), out


def test_resnet_subgraph():
target = Target("c")
global_workspace_pool = WorkspacePoolInfo(
"global_workspace",
[target],
)
tir_mod = ResnetStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)

tir_mod_with_offsets_ref = ResnetStructurePlanned

for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
before = ResnetStructure
expected = ResnetStructurePlanned
after = _plan_and_convert(before)
tvm.ir.assert_structural_equal(after, expected)


@tvm.script.ir_module
Expand Down Expand Up @@ -591,36 +580,116 @@ def __tvm_main__(


def test_tensor_intrin():
target = Target("c")
global_workspace_pool = WorkspacePoolInfo(
"global_workspace",
[target],
)

tir_mod = TensorIntrinStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["__tvm_main__"]
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
pool_allocations, emit_tvmscript_printable=True
)(tir_mod)

before = TensorIntrinStructure
after = _plan_and_convert(before)
expected = TensorIntrinStructurePlanned

for gv, ref_func in expected.functions.items():
actual_func = tir_mod_with_offsets[gv.name_hint]
tvm.ir.assert_structural_equal(actual_func, ref_func)
tvm.ir.assert_structural_equal(after, expected)


# Before/after test case for `_plan_and_convert` on a module whose main
# function creates two 256-element int8 allocations with `T.allocate`.
# NOTE(review): presumably `tvm.testing.CompareBeforeAfter` applies
# `transform` to `before` and compares the result with `expected` -- confirm
# against the TVM testing utilities.
class TestMergeAllocations(tvm.testing.CompareBeforeAfter):
# The transformation under test is the full plan-and-convert helper.
def transform(self):
return _plan_and_convert

# Input module: `__tvm_main__` allocates B and C and passes their handles
# to `subroutine` via `T.call_extern`.
def before(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")):
B = T.allocate([256], "int8")
T.call_extern("subroutine", A.data, B, dtype="int32")
C = T.allocate([256], "int8")
T.call_extern("subroutine", B, C, dtype="int32")
T.call_extern("subroutine", C, D.data, dtype="int32")

@T.prim_func
def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
for i in range(256):
B[i] = A[i]

return mod

# Expected module: both functions gain a `workspace_var` parameter bound to
# a 512-element uint8 `workspace` buffer; B and C become handles pointing at
# workspace[256] and workspace[0] respectively, and every `subroutine` call
# receives `workspace.data` as an additional argument.
def expected(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(
A: T.Buffer(256, "int8"),
D: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
B: T.handle("int8") = T.address_of(workspace[256])
T.call_extern("subroutine", A.data, B, workspace.data, dtype="int32")
C: T.handle("int8") = T.address_of(workspace[0])
T.call_extern("subroutine", B, C, workspace.data, dtype="int32")
T.call_extern("subroutine", C, D.data, workspace.data, dtype="int32")

@T.prim_func
def subroutine(
A: T.Buffer(256, "int8"),
B: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
for i in range(256):
B[i] = A[i]

return mod


# Before/after test case for `_plan_and_convert` on a module whose main
# function declares its two 256-element int8 buffers with `T.decl_buffer`.
# NOTE(review): presumably `tvm.testing.CompareBeforeAfter` applies
# `transform` to `before` and compares the result with `expected` -- confirm
# against the TVM testing utilities.
class TestMergeAllocationsWithDeclBuffer(tvm.testing.CompareBeforeAfter):
"""Like TestMergeAllocations, but using T.decl_buffer"""

# The transformation under test is the full plan-and-convert helper.
def transform(self):
return _plan_and_convert

# Input module: B and C are declared with `T.decl_buffer` and their `.data`
# handles are passed to `subroutine` via `T.call_extern`.
def before(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, "int8")):
B = T.decl_buffer([256], "int8")
T.call_extern("subroutine", A.data, B.data, dtype="int32")
C = T.decl_buffer([256], "int8")
T.call_extern("subroutine", B.data, C.data, dtype="int32")
T.call_extern("subroutine", C.data, D.data, dtype="int32")

@T.prim_func
def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
for i in range(256):
B[i] = A[i]

return mod

# Expected module: both functions gain a `workspace_var` parameter bound to
# a 512-element uint8 `workspace` buffer. The buffer declarations are kept,
# but each is now backed by a handle (`B_data` at workspace[256], `C_data`
# at workspace[0]) passed through the `data=` argument, and every
# `subroutine` call receives `workspace.data` as an additional argument.
def expected(self):
@I.ir_module
class mod:
@T.prim_func
def __tvm_main__(
A: T.Buffer(256, "int8"),
D: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
B_data: T.handle("int8") = T.address_of(workspace[256])
B = T.decl_buffer(256, "int8", data=B_data)
T.call_extern("subroutine", A.data, B.data, workspace.data, dtype="int32")
C_data: T.handle("int8") = T.address_of(workspace[0])
C = T.decl_buffer(256, "int8", data=C_data)
T.call_extern("subroutine", B.data, C.data, workspace.data, dtype="int32")
T.call_extern("subroutine", C.data, D.data, workspace.data, dtype="int32")

@T.prim_func
def subroutine(
A: T.Buffer(256, "int8"),
B: T.Buffer(256, "int8"),
workspace_var: T.handle("uint8"),
):
workspace = T.match_buffer(workspace_var, 512, "uint8", strides=[1], align=16)
for i in range(256):
B[i] = A[i]

return mod


if __name__ == "__main__":
Expand Down

0 comments on commit fa8a9f7

Please sign in to comment.