From ff3ca9eb32d7be59148dce74067435e0cd5eaa71 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 21 Jan 2022 10:50:32 +0000 Subject: [PATCH] [microNPU] enable USMP This commit enables USMP in the microNPU codegen and tests. Change-Id: Iafd7db8cd678f2b3cca8c06e5ea30e79a570faf9 --- include/tvm/tir/usmp/utils.h | 13 +- python/tvm/micro/model_library_format.py | 54 ++-- .../contrib/ethosu/tir_to_cs_translator.py | 230 +++++++++++++----- src/relay/backend/aot_executor_codegen.cc | 5 +- src/tir/usmp/analysis/extract_buffer_info.cc | 12 +- src/tir/usmp/transform/assign_pool_info.cc | 17 +- .../convert_pool_allocations_to_offsets.cc | 4 +- src/tir/usmp/utils.cc | 7 +- tests/python/contrib/test_ethosu/infra.py | 1 + 9 files changed, 236 insertions(+), 107 deletions(-) diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 582399865d6fd..da56121a3e578 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -251,24 +251,24 @@ struct AllocatedPoolInfoNode : public Object { PoolInfo pool_info; /*! \brief The allocated size into this pool */ Integer allocated_size; - /*! \brief An optional associated pool Var*/ - Optional pool_var; + /*! \brief An optional associated pool Var index of PrimFunc params*/ + Optional pool_var_idx; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pool_info", &pool_info); v->Visit("allocated_size", &allocated_size); - v->Visit("pool_var", &pool_var); + v->Visit("pool_var_idx", &pool_var_idx); } bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const { return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) && - equal(pool_var, other->pool_var); + equal(pool_var_idx, other->pool_var_idx); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pool_info); hash_reduce(allocated_size); - hash_reduce(pool_var); + hash_reduce(pool_var_idx); } static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo"; @@ -277,7 +277,8 @@ struct AllocatedPoolInfoNode : public Object { class AllocatedPoolInfo : public ObjectRef { public: - TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var()); + TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, + Integer pool_var_idx = Integer()); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode); }; diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 9f65a0bef1096..a14c20957765f 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -177,14 +177,13 @@ def _build_function_memory_map(function_metadata): """ device_max_workspace = dict() main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR] - num_targets = len(main_func_metadata.workspace_sizes.items()) + main_targets = dict(main_func_metadata.workspace_sizes).keys() from tvm.driver import tvmc # pylint: disable=import-outside-toplevel external_codegens = tvmc.composite_target.get_codegen_names() func_entries = [] target_local_entries = dict() - for i in range(num_targets): - main_target = main_func_metadata.workspace_sizes.items()[i][0] + for main_target in main_targets: device_max_workspace[main_target] = 0 for func_name, finfo in function_metadata.items(): if func_name == MAIN_FUNC_NAME_STR: @@ -197,22 +196,18 @@ def _build_function_memory_map(function_metadata): # 2. BYOC operator implementations do not currently export useful FunctionInfo. 
if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs: continue - assert ( - len(finfo.constant_sizes.items()) == num_targets - ), f"{func_name}: found {finfo.constant_sizes!r} vs {num_targets}" - assert len(finfo.io_sizes.items()) == num_targets - target = finfo.workspace_sizes.items()[i][0] - workspace_size = finfo.workspace_sizes.items()[i][1] - target_entry = { - "device": int(target.kind.device_type), - "workspace_size_bytes": int(workspace_size), - } - target_local_entries[func_name].append(target_entry) - if workspace_size > device_max_workspace.get(target, 0): - device_max_workspace[target] = workspace_size - # TODO(Mousius) - Remove this massive hack when Targets are unified - if target.kind.name in external_codegens: - device_max_workspace[main_target] += int(workspace_size) + if main_target in finfo.workspace_sizes.keys(): + workspace_size = finfo.workspace_sizes[main_target] + target_entry = { + "device": int(main_target.kind.device_type), + "workspace_size_bytes": int(workspace_size), + } + target_local_entries[func_name].append(target_entry) + if workspace_size > device_max_workspace.get(main_target, 0): + device_max_workspace[main_target] = workspace_size + # TODO(Mousius) - Remove this massive hack when Targets are unified + if main_target.kind.name in external_codegens: + device_max_workspace[main_target] += int(workspace_size) for func_name, target_entries_ in target_local_entries.items(): func_entry = { @@ -222,15 +217,22 @@ def _build_function_memory_map(function_metadata): func_entries.append(func_entry) target_main_entries = list() - for i in range(num_targets): - target = main_func_metadata.workspace_sizes.items()[i][0] - main_func_local_workspace = main_func_metadata.workspace_sizes.items()[i][1] - main_func_constants = main_func_metadata.constant_sizes.items()[i][1] - main_func_io = main_func_metadata.io_sizes.items()[i][1] + for main_target in main_targets: + main_func_local_workspace = main_func_metadata.workspace_sizes[main_target] + main_func_constants = ( + main_func_metadata.constant_sizes[main_target] + if main_target in main_func_metadata.constant_sizes.keys() + else 0 + ) + main_func_io = ( + main_func_metadata.io_sizes[main_target] + if main_target in main_func_metadata.io_sizes.keys() + else 0 + ) target_main_entries.append( { - "device": int(target.kind.device_type), - "workspace_size_bytes": int(device_max_workspace[target]) + "device": int(main_target.kind.device_type), + "workspace_size_bytes": int(device_max_workspace[main_target]) + int(main_func_local_workspace), "constants_size_bytes": int(main_func_constants), "io_size_bytes": int(main_func_io), diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index ecea6eb28f098..0275b7c19d374 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -42,15 +42,6 @@ class BufferType(Enum): shram = auto() -_REGION_MAP = { - BufferType.constant: 0, - BufferType.scratch: 1, - BufferType.input: 3, - BufferType.output: 4, - BufferType.shram: int((1 << 8) | (3 << 0)), -} - - class BufferInfo(NamedTuple): """A data structure to hold metadata of the buffer.""" @@ -81,6 +72,107 @@ def get_accelerator_arch_config(accel_type): return accel_config_str_map[accel_type] +class RegionOffset(NamedTuple): + """A data structure to hold region and address offset corresponding to a tensor""" + + region: int + offset: int + + +def 
analyze_scratch_memory_acesses(mod: tvm.IRModule, candidate_regions_for_scratch: List[int]): + """ + Parameters + ---------- + mod: tvm.IRModule + The TIR module containing ethosu extern calls + candidate_regions_for_scratch: List[int] + A list of region integers that could be used for scratch regions + + Returns + ------- + scratch_region_map : Dict[tvm.tir.Var, int] + A map between buffer vars to scratch regions they are assigned + tvm_backend_alloc_workspace_size : int + The size of tvm_backend_alloc_workspace call required to service + remaining allocate nodes if any + tvm_backend_alloc_workspace_region : int + The region associated with the tvm_backend_alloc_workspace + """ + scratch_region_map = dict() + pool_var_region_map = dict() + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + if "pool_args" in primfunc.attrs.keys(): + pool_args = primfunc.attrs["pool_args"] + for pool_arg in pool_args: + pool_param = primfunc.params[int(pool_arg.pool_var_idx)] + pool_var_region_map[pool_param] = candidate_regions_for_scratch.pop() + scratch_region_map[pool_param] = RegionOffset( + region=pool_var_region_map[pool_param], offset=None + ) + + def analyze_pool_access(stmt): + if isinstance(stmt, tvm.tir.stmt.LetStmt): + call_address_of = stmt.value + load = call_address_of.args[0] + pool_var = load.buffer_var + scratch_region_map[stmt.var] = RegionOffset( + region=pool_var_region_map[pool_var], offset=int(load.index) + ) + + tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access) + + tvmbaw_region = None + if len(candidate_regions_for_scratch) > 0: + tvmbaw_region = candidate_regions_for_scratch.pop() + + # Need a mutable data structure to be updated by the following function + # Therefore, using a list instead of int + tvmbaw_size = [0] + + # If there are tir.Allocate remaining by now, they need to be serviced via + # TVMBAW calls. 
+ def analyze_remaining_allocates(stmt): + if isinstance(stmt, tvm.tir.stmt.Allocate): + allocate = stmt + pointer_type = allocate.buffer_var.type_annotation + storage_scope = pointer_type.storage_scope + if storage_scope == "global": + dtype_bytes = np.iinfo(np.dtype(allocate.dtype)).bits // 8 + size_in_bytes = int(dtype_bytes * np.prod(list(allocate.extents))) + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + address = tvmbaw_size[0] + tvmbaw_size[0] += size_in_bytes + scratch_region_map[allocate.buffer_var] = RegionOffset( + region=tvmbaw_region, offset=address + ) + + tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_remaining_allocates) + + return ( + scratch_region_map, + tvmbaw_size[0], + tvmbaw_region, + ) + + +def _get_region(buffer_type, var=None, scratch_region_map=None): + """A helper to obtain regions for buffer_types and buffer vars""" + static_regions = { + BufferType.constant: 0, + BufferType.input: 3, + BufferType.output: 4, + BufferType.shram: int((1 << 8) | (3 << 0)), + } + if buffer_type in static_regions.keys(): + return static_regions[buffer_type] + assert buffer_type == BufferType.scratch + assert var in scratch_region_map.keys(), f"{var} is not analyzed for scratch regions" + return scratch_region_map[var].region + + def translate(tir_module, params): """This will take an tir module for the NPU and compile to command stream @@ -105,22 +197,27 @@ def translate(tir_module, params): base_addresses : List[util.BaseAddress] base addresses to be used by the driver """ - + candidate_regions_for_scratch = [5, 2, 1] + ( + scratch_region_map, + tvmbaw_workspace_size, + tvmbaw_region, + ) = analyze_scratch_memory_acesses(tir_module, candidate_regions_for_scratch) buffer_info = extract_buffer_info(tir_module, params) call_extern_list = extract_call_extern_list(tir_module) _npu_ops = list() for call_extern in call_extern_list: _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) - _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops) - base_addresses = extract_param_base_addresses(tir_module, buffer_info) - if scratch_size > 0: + _npu_ops, constant_data = assign_addresses(buffer_info, _npu_ops, scratch_region_map) + base_addresses = extract_param_base_addresses(tir_module, buffer_info, scratch_region_map) + if tvmbaw_workspace_size: base_addresses.append( util.BaseAddress( - "scratch", - None, - _REGION_MAP[BufferType.scratch], - scratch_size, - True, + name="tvmbaw", + primfunc_param_idx=None, + region=tvmbaw_region, + size=tvmbaw_workspace_size, + is_runtime_allocation=True, ) ) target_accel_config = vela_api.get_accelerator_config() @@ -129,7 +226,7 @@ def translate(tir_module, params): return payload.hex(), constant_data, base_addresses -def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]: +def extract_param_base_addresses(mod, buffer_info, scratch_region_map) -> List[util.BaseAddress]: """This function extracts base addresses to be used by the driver Parameters @@ -161,7 +258,12 @@ def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]: element_size_bytes = np.iinfo(dtype).bits // 8 size_bytes = element_size_bytes * np.prod(list(buffer.shape)) base_addresses.append( - util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes) + util.BaseAddress( + param.name, + idx, + _get_region(buffer_info[param].btype, param, scratch_region_map), + size_bytes, + ) ) idx += 1 @@ -227,39 +329,42 @@ def 
extract_buffer_info( const_data, const_data.shape, const_data.dtype, BufferType.constant ) - for param in primfunc.params: + pool_param_indices = list() + if "pool_args" in primfunc.attrs.keys(): + pool_args = primfunc.attrs["pool_args"] + pool_param_indices = [allocated_pool_info.pool_var_idx for allocated_pool_info in pool_args] + + for idx, param in enumerate(primfunc.params): if param not in buffer_info.keys(): + if idx in pool_param_indices: + btype = BufferType.scratch + else: + btype = BufferType.input_or_output buffer_info[param] = BufferInfo( None, None, None, - BufferType.input_or_output, + btype, ) def populate_allocate_buffer_info(stmt): if isinstance(stmt, tvm.tir.stmt.Allocate): allocate = stmt - if "placeholder" in allocate.buffer_var.name: - storage_scope = allocate.buffer_var.name.split(".")[-1] - else: - storage_scope = "global" - + pointer_type = allocate.buffer_var.type_annotation + storage_scope = pointer_type.storage_scope if storage_scope == "local": - buffer_type = BufferType.shram - else: - buffer_type = BufferType.scratch - buffer_info[allocate.buffer_var] = BufferInfo( - None, - allocate.extents, - allocate.dtype, - buffer_type, - ) + buffer_info[allocate.buffer_var] = BufferInfo( + None, + allocate.extents, + allocate.dtype, + BufferType.shram, + ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) return buffer_info -def assign_addresses(buffer_info, npu_ops): +def assign_addresses(buffer_info, npu_ops, scratch_region_map): """This function will assign addresses to tensors within two buffers : scratch and constants. The scratch is the buffer created to hold all intermediary data @@ -272,14 +377,14 @@ def assign_addresses(buffer_info, npu_ops): The key is the buffer name to BufferInfo npu_ops : list A list of Vela NpuOps with tir.Loads for addresses + scratch_region_map : Dict[tvm.tir.Var, RegionOffset] + A buffer_var to region and offset map. Returns ------- npu_ops : list A list of Vela NpuOps with addesses within scratch and constant buffers constant_tensor : NDArray A unified constant data array of uint8 as the constant buffer - scratch_size : int - The size of the scratch tensor. 
""" def replace_npu_fm_with_address(npu_fm): @@ -290,21 +395,34 @@ def replace_npu_fm_with_address(npu_fm): assert npu_fm.tiles.addresses[1:] == [0, 0, 0] npu_fm.tiles.addresses[1:] = [0, 0, 0] buffer = npu_fm.tiles.addresses[0].buffer_var - assert buffer in buffer_addresses.keys() - address, buffer_type = buffer_addresses[buffer] + + if buffer in scratch_region_map.keys(): + address = scratch_region_map[buffer].offset + region = scratch_region_map[buffer].region + else: + assert buffer in buffer_addresses.keys() + address, buffer_type = buffer_addresses[buffer] + region = _get_region(buffer_type) + index = npu_fm.tiles.addresses[0].index * ( np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 ) npu_fm.tiles.addresses[0] = address + int(index) - npu_fm.region = _REGION_MAP[buffer_type] + npu_fm.region = region return npu_fm def replace_npu_address_range_with_address(npu_addr_range): assert isinstance(npu_addr_range.address, tvm.tir.Load) buffer = npu_addr_range.address.buffer_var + if buffer in scratch_region_map.keys(): + return vapi.NpuAddressRange( + scratch_region_map[buffer].region, + scratch_region_map[buffer].offset, + npu_addr_range.length, + ) assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] - return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) + return vapi.NpuAddressRange(_get_region(buffer_type), address, npu_addr_range.length) def replace_tir_loads(npu_object): if isinstance(npu_object, vapi.NpuFeatureMap): @@ -325,7 +443,6 @@ def classify_io(buffer): raise ValueError(f"Unused IO : {buffer} in tir module.") - scratch_size = 0 constant_hex_data = [] total_constant_len = 0 buffer_addresses = dict() @@ -345,8 +462,10 @@ def classify_io(buffer): constant_hex_data.append(constant_tensor) total_constant_len += len(constant_tensor) // 2 else: - if info.btype == BufferType.input_or_output: - buffer_type = classify_io(_buffer) + if info.btype == BufferType.input_or_output or info.btype == BufferType.input: + buffer_type = info.btype + if info.btype == BufferType.input_or_output: + buffer_type = classify_io(_buffer) assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) @@ -359,14 +478,8 @@ def classify_io(buffer): address = arch_config.lut_start_address buffer_addresses[_buffer] = (address, info.btype) else: - dtype_bytes = np.iinfo(np.dtype(info.dtype)).bits // 8 - size_in_bytes = int(dtype_bytes * np.prod(list(info.shape))) - # Every memory address the NPU access have to be 16 byte aligned - size_in_bytes = util.round_up(size_in_bytes, 16) + # These buffer_vars are already updated in scratch_region_map assert info.btype == BufferType.scratch - address = scratch_size - scratch_size += size_in_bytes - buffer_addresses[_buffer] = (address, info.btype) for npu_op in npu_ops: for attr_name, attr in npu_op.__dict__.items(): @@ -379,11 +492,7 @@ def classify_io(buffer): setattr(npu_op, attr_name, replace_tir_loads(attr)) constant_data = "".join(constant_hex_data) - return ( - npu_ops, - constant_data, - scratch_size, - ) + return (npu_ops, constant_data) def translate_ethosu_tir_call_extern(tir_call_extern): @@ -733,17 +842,18 @@ def _create_npu_rounding_mode( def _create_npu_dma_op(serial_copy): """This is a helper function to capture the list of arguments to create a NpuDmaOperation object""" + data_type_bytes = np.iinfo(np.dtype(serial_copy.read_address.dtype)).bits // 8 src = 
vapi.NpuAddressRange( # region will be updated later region=0, address=serial_copy.read_address, - length=int(serial_copy.length.value), + length=int(serial_copy.length.value) * data_type_bytes, ) dest = vapi.NpuAddressRange( # region will be updated later region=0, address=serial_copy.write_address, - length=int(serial_copy.length.value), + length=int(serial_copy.length.value) * data_type_bytes, ) return vapi.NpuDmaOperation(src, dest) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index f076efeb4ac50..1d3c75a3e41d9 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -948,8 +948,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir_main_func->GetAttr>(tvm::attr::kPoolArgs); if (allocated_pool_infos) { for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - pool_vars.push_back(allocated_pool_info->pool_var.value()); - pool_var_info.Set(allocated_pool_info->pool_var.value(), allocated_pool_info); + int pool_var_index = allocated_pool_info->pool_var_idx.value()->value; + pool_vars.push_back(tir_main_func->params[pool_var_index]); + pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info); } } Array devices = ListDevices(); diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index fb4fb52c507e1..5bd5fa700ae22 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -264,11 +264,13 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work // with buffers that have global storage_scope - if (!current_scope_info.for_loop.defined()) { - RecordAllocateNodeInfo(op); - } else if (current_scope_info.for_loop.defined() && - current_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global") { - RecordAllocateNodeInfo(op); + if (storage_scope == "global") { + if (!current_scope_info.for_loop.defined()) { + RecordAllocateNodeInfo(op); + } else if (current_scope_info.for_loop.defined() && + current_scope_info.for_loop->kind == ForKind::kSerial) { + RecordAllocateNodeInfo(op); + } } StmtExprVisitor::VisitStmt(op->body); current_scope_info.allocate_nodes.erase(GetRef(op)); diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index 516ddd1a241bf..2af0a4be3f421 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -48,9 +48,7 @@ class PoolInfoAssigner : public StmtExprMutator { ICHECK(target_host) << "main function does not have a target attr"; Array pool_infos = module->GetAttr>(tvm::attr::kPoolInfoIRModuleAttr) - .value_or({usmp::PoolInfo("global_workspace", - {{target_host.value(), usmp::kTargetPoolReadWriteAccess}}, - usmp::kUnrestrictedPoolSizeHint, Bool(true))}); + .value_or({CreateDefaultMemoryPool(module)}); for (const usmp::PoolInfo& pool_info : pool_infos) { for (const auto& kv : pool_info->target_access) { Target tgt = kv.first; @@ -73,8 +71,21 @@ class PoolInfoAssigner : public StmtExprMutator { IRModule mod_; Map> target_pool_infos_; PrimFunc func_; + usmp::PoolInfo CreateDefaultMemoryPool(const IRModule& module); }; +usmp::PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) { + Map target_access; + for (const auto& kv : module->functions) { + BaseFunc func = 
kv.second; + Optional target = func->GetAttr(tvm::attr::kTarget); + ICHECK(target) << "main function does not have a target attr"; + target_access.Set(target.value(), usmp::kTargetPoolReadWriteAccess); + } + return usmp::PoolInfo("global_workspace", target_access, usmp::kUnrestrictedPoolSizeHint, + Bool(true)); +} + Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { Optional tgt = func_->GetAttr(tvm::attr::kTarget).value(); ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index cd797681d4743..999ca37d21280 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -189,7 +189,7 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda si.params.push_back(pool_var); si.pools_to_params.Set(pool_info, pool_var); si.allocated_pool_params.push_back(AllocatedPoolInfo( - allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var)); + allocated_pool_info->pool_info, allocated_pool_info->allocated_size, si.params.size() - 1)); int pool_size = all_pools_sizes_[pool_info]; String buffer_var_name = pool_ref_name + "_buffer_var"; @@ -258,7 +258,7 @@ Array PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( allocate_buf_to_let_var_.find(Downcast(arg)) != allocate_buf_to_let_var_.end()) { ret.push_back(allocate_buf_to_let_var_[Downcast(arg)]); } else { - ret.push_back(arg); + ret.push_back(VisitExpr(arg)); } } return ret; diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 1fff70f5892e9..374d2c11d775d 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -141,12 +141,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) { +AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, + Integer pool_var_idx) { auto allocated_poolinfo_node = make_object(); allocated_poolinfo_node->pool_info = pool_info; allocated_poolinfo_node->allocated_size = allocated_size; - if (pool_var.defined()) { - allocated_poolinfo_node->pool_var = pool_var; + if (pool_var_idx.defined()) { + allocated_poolinfo_node->pool_var_idx = pool_var_idx; } data_ = std::move(allocated_poolinfo_node); } diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 0b058a94fb608..50f0174f2b0bd 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -215,6 +215,7 @@ def create_test_runner(accel="ethos-u55-256"): "relay.ext.ethos-u.options": { "accelerator_config": accel, }, + "tir.usmp.enable": True, }, )
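
Illustrative sketch, not part of the patch: how a consumer of the new AllocatedPoolInfo.pool_var_idx field can recover the pool parameter from a PrimFunc signature, mirroring what analyze_scratch_memory_acesses() in tir_to_cs_translator.py and the aot_executor_codegen.cc change above do. The helper name pool_params_of and the use of pool_name as the dictionary key are assumptions made for the example.

import tvm

def pool_params_of(primfunc: tvm.tir.PrimFunc):
    """Map each memory pool name to the PrimFunc parameter it is passed through."""
    pools_to_params = {}
    if "pool_args" in primfunc.attrs.keys():
        for allocated_pool_info in primfunc.attrs["pool_args"]:
            # pool_var_idx is an index into primfunc.params rather than a tir.Var,
            # so the lookup goes through the function signature itself.
            param = primfunc.params[int(allocated_pool_info.pool_var_idx)]
            pools_to_params[str(allocated_pool_info.pool_info.pool_name)] = param
    return pools_to_params

Carrying an index instead of the Var itself presumably keeps the annotation stable when the PrimFunc parameters are rebuilt by later passes, which would be why both the translator and aot_executor_codegen.cc now index into params.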
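
Also illustrative rather than part of the patch: a minimal sketch of enabling USMP in a build the same way the updated test runner in infra.py does, through the "tir.usmp.enable" PassContext option. The model, target string, and executor/runtime settings here are assumptions chosen for a typical AOT micro setup, not taken from the patch.

import tvm
from tvm import relay
from tvm.relay.backend import Executor, Runtime

def build_with_usmp(relay_mod, params):
    executor = Executor("aot", {"interface-api": "c", "unpacked-api": True})
    runtime = Runtime("crt")
    with tvm.transform.PassContext(opt_level=3, config={"tir.usmp.enable": True}):
        # With USMP on, intermediary buffers are served from workspace pools passed
        # into the AOT entrypoint; any allocations USMP cannot place fall back to
        # TVMBackendAllocWorkspace, which the translator above accounts for as the
        # "tvmbaw" base address.
        return relay.build(relay_mod, target="c", executor=executor,
                           runtime=runtime, params=params)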