[microNPU] enable USMP
* Fixed unit tests
* Added a guard so that USMP buffer-info extraction ignores non-global
  allocates
* Fixed export_model_library_format to use the target_kind type

Change-Id: I9c6c90d8787c39697fca24af299f8309f40d3743
manupak committed Jan 28, 2022
1 parent c1edeb8 commit a854d61
Showing 5 changed files with 107 additions and 71 deletions.
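
For context: USMP (the Unified Static Memory Planner) is switched on through the `tir.usmp.enable` PassContext option that this commit threads through the Ethos-U test infrastructure below. A minimal sketch of enabling it at build time, assuming a partitioned Relay module `mod` and `params` already exist; the executor/runtime choices here are illustrative assumptions, not taken from this commit:

import tvm
from tvm import relay

# `mod` and `params` are assumed to exist, e.g. from partition_for_ethosu().
with tvm.transform.PassContext(
    opt_level=3,
    config={"tir.usmp.enable": True},  # let USMP plan tir.allocate buffers statically
):
    lib = relay.build(
        mod,
        target="c",
        executor=relay.backend.Executor("aot"),
        runtime=relay.backend.Runtime("crt"),
        params=params,
    )
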
99 changes: 51 additions & 48 deletions python/tvm/micro/model_library_format.py
@@ -177,37 +177,26 @@ def _build_function_memory_map(function_metadata):
     """
     device_max_workspace = dict()
     main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
-    main_targets = dict(main_func_metadata.workspace_sizes).keys()
-    from tvm.driver import tvmc  # pylint: disable=import-outside-toplevel
-
-    external_codegens = tvmc.composite_target.get_codegen_names()
     func_entries = []
     target_local_entries = dict()
-    for main_target in main_targets:
-        device_max_workspace[main_target] = 0
-        for func_name, finfo in function_metadata.items():
-            if func_name == MAIN_FUNC_NAME_STR:
-                continue
-            target_local_entries[func_name] = list()
-
-        for func_name, finfo in function_metadata.items():
-            # Skip a few unsupported cases:
-            # 1. The main function metadata is exported elsewhere.
-            # 2. BYOC operator implementations do not currently export useful FunctionInfo.
-            if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs:
-                continue
-            if main_target in finfo.workspace_sizes.keys():
-                workspace_size = finfo.workspace_sizes[main_target]
-                target_entry = {
-                    "device": int(main_target.kind.device_type),
-                    "workspace_size_bytes": int(workspace_size),
-                }
-                target_local_entries[func_name].append(target_entry)
-                if workspace_size > device_max_workspace.get(main_target, 0):
-                    device_max_workspace[main_target] = workspace_size
-                # TODO(Mousius) - Remove this massive hack when Targets are unified
-                if main_target.kind.name in external_codegens:
-                    device_max_workspace[main_target] += int(workspace_size)
+
+    for func_name, finfo in function_metadata.items():
+        # Skip a few unsupported cases:
+        # 1. The main function metadata is exported elsewhere.
+        # 2. BYOC operator implementations do not currently export useful FunctionInfo.
+        if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs:
+            continue
+        if func_name not in target_local_entries.keys():
+            target_local_entries[func_name] = list()
+        for target in dict(finfo.workspace_sizes).keys():
+            workspace_size = finfo.workspace_sizes[target]
+            target_entry = {
+                "device": int(target.kind.device_type),
+                "workspace_size_bytes": int(workspace_size),
+            }
+            target_local_entries[func_name].append(target_entry)
+            if workspace_size >= device_max_workspace.get(int(target.kind.device_type), 0):
+                device_max_workspace[int(target.kind.device_type)] = workspace_size
 
     for func_name, target_entries_ in target_local_entries.items():
         func_entry = {
@@ -216,32 +205,46 @@ def _build_function_memory_map(function_metadata):
         }
         func_entries.append(func_entry)
 
-    target_main_entries = list()
-    for main_target in main_targets:
-        main_func_local_workspace = main_func_metadata.workspace_sizes[main_target]
-        main_func_constants = (
-            main_func_metadata.constant_sizes[main_target]
-            if main_target in main_func_metadata.constant_sizes.keys()
-            else 0
-        )
-        main_func_io = (
-            main_func_metadata.io_sizes[main_target]
-            if main_target in main_func_metadata.io_sizes.keys()
-            else 0
-        )
-        target_main_entries.append(
-            {
-                "device": int(main_target.kind.device_type),
-                "workspace_size_bytes": int(device_max_workspace[main_target])
-                + int(main_func_local_workspace),
-                "constants_size_bytes": int(main_func_constants),
-                "io_size_bytes": int(main_func_io),
-            }
-        )
+    target_main_entries = dict()
+
+    def _create_empty_entry(target_device_type):
+        return {
+            "device": int(target_device_type),
+            "workspace_size_bytes": 0,
+            "constants_size_bytes": 0,
+            "io_size_bytes": 0,
+        }
+
+    for target in dict(main_func_metadata.workspace_sizes).keys():
+        main_func_local_workspace = main_func_metadata.workspace_sizes[target]
+        target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+            int(target.kind.device_type)
+        )
+        target_main_entries[int(target.kind.device_type)]["workspace_size_bytes"] = int(
+            device_max_workspace[int(target.kind.device_type)]
+        ) + int(main_func_local_workspace)
+
+    for target in dict(main_func_metadata.constant_sizes).keys():
+        if int(target.kind.device_type) not in target_main_entries.keys():
+            target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+                int(target.kind.device_type)
+            )
+        target_main_entries[int(target.kind.device_type)]["constants_size_bytes"] = int(
+            main_func_metadata.constant_sizes[target]
+        )
+
+    for target in dict(main_func_metadata.io_sizes).keys():
+        if int(target.kind.device_type) not in target_main_entries.keys():
+            target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+                int(target.kind.device_type)
+            )
+        target_main_entries[int(target.kind.device_type)]["io_size_bytes"] = int(
+            main_func_metadata.io_sizes[target]
+        )
 
     ret = {
         "operator_functions": func_entries,
-        "main": target_main_entries,
+        "main": list(target_main_entries.values()),
     }
     return ret

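Taken together, the rewrite above replaces the per-main-target scan (and the BYOC workspace hack) with a single pass over all functions, keyed by device type, and fills in workspace, constant, and I/O sizes independently, so a device that appears only in `constant_sizes` or `io_sizes` still gets an entry. Roughly, the returned map now has this shape; the function name and every byte count below are invented for illustration:

# Illustrative shape of _build_function_memory_map's return value.
{
    "operator_functions": [
        {
            "function_name": "tvmgen_default_fused_nn_conv2d",  # hypothetical
            "workspace": [{"device": 1, "workspace_size_bytes": 2048}],
        }
    ],
    "main": [
        {
            "device": 1,
            "workspace_size_bytes": 2304,  # device max workspace + main's own workspace
            "constants_size_bytes": 4096,
            "io_size_bytes": 301056,
        }
    ],
}
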
6 changes: 4 additions & 2 deletions src/tir/usmp/analysis/extract_buffer_info.cc
@@ -36,6 +36,8 @@
 
 #include <stack>
 
+#include "../../../runtime/thread_storage_scope.h"
+
 namespace tvm {
 namespace tir {
 namespace usmp {
@@ -257,14 +259,14 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
 void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
   ScopeInfo& current_scope_info = scope_stack_.top();
   const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
-  const auto& storage_scope = type->storage_scope;
+  const auto& storage_scope = runtime::StorageScope::Create(type->storage_scope);
 
   // If the allocate is in a for loop, USMP currently only looks at serial for loops.
   // If it is not a serial for loop, the memory planner will omit it from the current memory
   // planning process, leaving it as a tir.allocate node for codegen. Additionally, USMP can
   // only work with buffers that have a global storage_scope.
 
-  if (storage_scope == "global") {
+  if (storage_scope.rank == runtime::StorageRank::kGlobal) {
     if (!current_scope_info.for_loop.defined()) {
       RecordAllocateNodeInfo(op);
     } else if (current_scope_info.for_loop.defined() &&
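The C++ change above swaps an exact string comparison for a rank check: parsing the scope with runtime::StorageScope::Create and testing `rank` keeps every globally-ranked allocate in play, so pool-tagged scope strings such as "global.workspace" are no longer skipped the way the old `storage_scope == "global"` test skipped them. A small Python sketch of two scope strings the guard now treats alike; the construction here is illustrative:

import tvm

# Both annotations carry StorageRank kGlobal on the C++ side; the old
# string comparison only matched the first.
plain = tvm.ir.PointerType(tvm.ir.PrimType("int8"), "global")
pooled = tvm.ir.PointerType(tvm.ir.PrimType("int8"), "global.workspace")
print(plain.storage_scope, pooled.storage_scope)  # global global.workspace
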
10 changes: 6 additions & 4 deletions tests/python/contrib/test_ethosu/infra.py
@@ -189,7 +189,7 @@ def deserialize_command_stream(blob):
     return cmms
 
 
-def create_test_runner(accel="ethos-u55-256"):
+def create_test_runner(accel="ethos-u55-256", enable_usmp=True):
     file_dir = os.path.dirname(os.path.abspath(__file__))
     test_root = os.path.join(file_dir, "reference_system")
     _, ethosu_variant, ethosu_macs = accel.split("-")
@@ -215,13 +215,15 @@ def create_test_runner(accel="ethos-u55-256"):
             "relay.ext.ethos-u.options": {
                 "accelerator_config": accel,
             },
-            "tir.usmp.enable": True,
+            "tir.usmp.enable": enable_usmp,
         },
     )
 
 
-def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0):
-    test_runner = create_test_runner(accel)
+def build_source(
+    module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0, enable_usmp=True
+):
+    test_runner = create_test_runner(accel, enable_usmp)
     return compile_models(
         models=AOTTestModel(
             module=module,
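With `enable_usmp` threaded through `create_test_runner`, a test can now compile the same module with the memory planner on and off. A hedged usage sketch; `module`, `input_data`, and `output_data` are assumed to come from the surrounding test helpers:

# Assumed to exist already: module, input_data, output_data.
with_usmp = infra.build_source(
    module, input_data, output_data, accel="ethos-u55-128", enable_usmp=True
)
without_usmp = infra.build_source(
    module, input_data, output_data, accel="ethos-u55-128", enable_usmp=False
)
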
5 changes: 3 additions & 2 deletions tests/python/contrib/test_ethosu/test_networks.py
@@ -37,7 +37,8 @@
 ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"]
 
 
-def test_forward_mobilenet_v1(accel_type="ethos-u55-256"):
+@pytest.mark.parametrize("enable_usmp", [True, False])
+def test_forward_mobilenet_v1(enable_usmp, accel_type="ethos-u55-256"):
     """Test the Mobilenet V1 TF Lite model."""
     np.random.seed(23)
     tflite_model_file = tf_testing.get_workload_official(
@@ -59,7 +60,7 @@ def test_forward_mobilenet_v1(accel_type="ethos-u55-256"):
 
     mod = partition_for_ethosu(relay_mod, params)
     compiled_models = infra.build_source(
-        mod, input_data, output_data, accel_type, output_tolerance=10
+        mod, input_data, output_data, accel_type, output_tolerance=10, enable_usmp=enable_usmp
     )
     infra.verify_source(compiled_models, accel_type)
 
58 changes: 43 additions & 15 deletions tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -645,7 +645,7 @@ def populate_ethosu_copy_calls(stmt):
             {
                 "src": "placeholder_5",
                 "dest": "placeholder_d_global",
-                "length": 8,
+                "length": 32,
             },
         ],
     },
@@ -851,24 +851,45 @@ def _check_buffer(address, region, length, buffer_var):
                 length, dtype=buffer_dtype
             )
         elif buffer_type == tir_to_cs_translator.BufferType.scratch:
             shape = list(buffer_info[buffer_var].shape)
             assert length == np.prod(shape)
-            assert address < scratch_size
+            assert address < tvmbaw_workspace_size
 
-            size_in_bytes = int(np.prod(shape)) * dtype_bytes
+            size_in_bytes = allocate_node_sizes[buffer_var]
             # Every buffer is adjusted to align to 16 bytes
             size_in_bytes = util.round_up(size_in_bytes, 16)
-            assert address + size_in_bytes <= scratch_size
+            assert address + size_in_bytes <= tvmbaw_workspace_size
             # The scratch area should not be used by any other buffer
-            assert not scratch_mask[address : address + size_in_bytes].any()
+            assert not tvmbaw_workspace_mask[address : address + size_in_bytes].any()
             # The scratch area is marked as used
-            scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8")
+            tvmbaw_workspace_mask[address : address + size_in_bytes] = np.ones(
+                size_in_bytes, dtype="uint8"
+            )
         elif buffer_type == tir_to_cs_translator.BufferType.input:
             assert address == 0
         else:
             assert buffer_type == tir_to_cs_translator.BufferType.output
             assert address == 0
 
+    def _get_allocate_node_sizes(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+        _allocate_node_sizes = dict()
+
+        def analyze_remaining_allocates(stmt):
+            if isinstance(stmt, tvm.tir.stmt.Allocate):
+                allocate = stmt
+                pointer_type = allocate.buffer_var.type_annotation
+                storage_scope = pointer_type.storage_scope
+                if storage_scope == "global":
+                    dtype_bytes = np.iinfo(np.dtype(allocate.dtype)).bits // 8
+                    size_in_bytes = int(dtype_bytes * np.prod(list(allocate.extents)))
+                    # Every memory address the NPU accesses has to be 16-byte aligned
+                    size_in_bytes = util.round_up(size_in_bytes, 16)
+                    _allocate_node_sizes[allocate.buffer_var] = size_in_bytes
+
+        tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_remaining_allocates)
+        return _allocate_node_sizes
+
     def verify(npu_ops):
         """This wrapper verifies that the allocated addresses match the original TIR buffers"""
         checked_buffers = set()
@@ -933,22 +954,29 @@ def check_buffer(address, region, length, buffer_var):
     tir_mod = test_case["tir_module"]
     tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
     tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
+    candidate_regions_for_scratch = [5, 2, 1]
+    (
+        scratch_region_map,
+        tvmbaw_workspace_size,
+        _,
+    ) = tir_to_cs_translator.analyze_scratch_memory_acesses(
+        tir_mod, candidate_regions_for_scratch
+    )
+    allocate_node_sizes = _get_allocate_node_sizes(tir_mod)
     buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
     extern_calls = extract_call_extern_list(tir_mod)
     _npu_ops = list()
     for extern_call in extern_calls:
         _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
     npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops)
-    (
-        _npu_ops,
-        constant_hex_string,
-        scratch_size,
-    ) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops)
-    scratch_mask = np.zeros(scratch_size, dtype="uint8")
+    (_npu_ops, constant_hex_string) = tir_to_cs_translator.assign_addresses(
+        buffer_info, _npu_ops, scratch_region_map
+    )
+    tvmbaw_workspace_mask = np.zeros(tvmbaw_workspace_size, dtype="uint8")
     constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8")
     verify(_npu_ops)
     # This will only be 1 if all allocated scratch is used.
-    assert np.prod(scratch_mask) == 1
+    assert np.prod(tvmbaw_workspace_mask) == 1
     # This will only be 1 if all constant tensors are read at least once.
     assert np.prod(constant_tensor_read_mask) == 1
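
The occupancy-mask checks above are the heart of the test: every scratch buffer claims a byte range of the TVMBAW (TVMBackendAllocWorkspace) workspace exactly once, and `np.prod(mask) == 1` holds only when every byte was claimed. The technique in isolation, with invented sizes and placements:

import numpy as np

# Invented workspace size and (address, size) placements, to show the check.
workspace_size = 64
mask = np.zeros(workspace_size, dtype="uint8")
for address, size in [(0, 16), (16, 32), (48, 16)]:
    assert not mask[address : address + size].any()  # ranges must not overlap
    mask[address : address + size] = 1  # claim the range
assert np.prod(mask) == 1  # passes only if the whole workspace is covered
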
