[microTVM][zephyr] Add support for host-driven AoT execution on zephyr #11650

Merged: 4 commits, Jun 14, 2022. Changes shown from 1 commit.
@@ -36,7 +36,7 @@
 #define TVM_CRT_MAX_ARGS 10

 /*! Size of the global function registry, in bytes. */
-#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
+#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512

 /*! Maximum number of registered modules. */
 #define TVM_CRT_MAX_REGISTERED_MODULES 2
@@ -420,7 +420,7 @@ def _create_prj_conf(self, project_dir, options):
 API_SERVER_CRT_LIBS_TOKEN = "<API_SERVER_CRT_LIBS>"

 CRT_LIBS_BY_PROJECT_TYPE = {
-    "host_driven": "microtvm_rpc_server microtvm_rpc_common common",
+    "host_driven": "microtvm_rpc_server microtvm_rpc_common aot_executor_module aot_executor common",
     "aot_demo": "memory microtvm_rpc_common common",
 }
10 changes: 8 additions & 2 deletions python/tvm/micro/session.py
@@ -39,7 +39,7 @@

 @register_error
 class SessionTerminatedError(Exception):
-    """Raised when a transport read operationd discovers that the remote session is terminated."""
+    """Raised when a transport read operation discovers that the remote session is terminated."""


 class Session:
@@ -86,12 +86,18 @@ def __init__(

         self._rpc = None
         self._graph_executor = None
+        self._enable_rpc_logger = False

         self._exit_called = False

     def get_system_lib(self):
         return self._rpc.get_function("runtime.SystemLib")()

+    def get_aot_lib(self):
+        return self._rpc.get_function("tvm.aot_executor.create")(
+            self.get_system_lib(), self.device, "default"
+        )
+
     def _wrap_transport_read(self, n, timeout_microsec):
         try:
             return self.transport.read(
@@ -133,7 +139,7 @@ def __enter__(self):
                 int(timeouts.session_start_timeout_sec * 1e6),
                 int(timeouts.session_established_timeout_sec * 1e6),
                 self._cleanup,
-                False,
+                self._enable_rpc_logger,
             )
         )
         self.device = self._rpc.cpu(0)
2 changes: 1 addition & 1 deletion python/tvm/runtime/ndarray.py
@@ -127,7 +127,7 @@ def __setitem__(self, in_slice, value):
             raise TypeError("type %s not supported" % str(type(value)))

     def copyfrom(self, source_array):
-        """Perform an synchronize copy from the array.
+        """Perform a synchronized copy from the array.

         Parameters
         ----------
12 changes: 10 additions & 2 deletions src/runtime/crt/aot_executor/aot_executor.c
@@ -173,21 +173,29 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle,
   for (i = 0; i < md->num_inputs; ++i) {
     LOG_DEBUG("input allocate[%d]: %s\n", i, md->inputs[i].name);

+    TVMNDArray* array = &executor->args[arg_idx++];
+
     status = TVMNDArray_Empty(md->inputs[i].num_shape, md->inputs[i].shape, md->inputs[i].dtype,
-                              executor->device, &executor->args[arg_idx++]);
+                              executor->device, array);
     if (status != 0) {
       return status;
     }
+
+    TVMNDArray_IncrementReference(array);
   }

   for (i = 0; i < md->num_outputs; ++i) {
     LOG_DEBUG("output allocate[%d]: %s\n", i, md->outputs[i].name);

+    TVMNDArray* array = &executor->args[arg_idx++];
+
     status = TVMNDArray_Empty(md->outputs[i].num_shape, md->outputs[i].shape, md->outputs[i].dtype,
-                              executor->device, &executor->args[arg_idx++]);
+                              executor->device, array);
     if (status != 0) {
       return status;
     }
+
+    TVMNDArray_IncrementReference(array);
   }

   for (i = 0; i < md->num_pools; ++i) {
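Both hunks apply the same ownership rule: every tensor the executor allocates for itself immediately receives one reference, so a later TVMArrayFree issued by the host against a handle to that tensor only drops the host's reference instead of freeing memory the executor still owns. A minimal sketch of that rule in isolation (not code from this PR; the device, dtype, and shape values are placeholders):

    #include <tvm/runtime/crt/internal/common/ndarray.h>

    /* Allocate one executor-owned argument tensor and take the executor's
     * reference on it, mirroring the pattern in TVMAotExecutor_Init. */
    static int AllocateArg(TVMNDArray* arg) {
      DLDevice dev = {kDLCPU, 0};            /* placeholder device */
      DLDataType dtype = {kDLFloat, 32, 1};  /* placeholder dtype  */
      tvm_index_t shape[2] = {1, 4};         /* placeholder shape  */

      int status = TVMNDArray_Empty(2, shape, dtype, dev, arg);
      if (status != 0) {
        return status;
      }

      TVMNDArray_IncrementReference(arg);  /* the executor's own reference */
      return 0;
    }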
30 changes: 24 additions & 6 deletions src/runtime/crt/aot_executor_module/aot_executor_module.c
@@ -80,13 +80,27 @@ int32_t TVMAotExecutorModule_NotImplemented(TVMValue* args, int* tcodes, int nar

 int32_t TVMAotExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
                                       int* ret_tcodes, void* resource_handle) {
-  int index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str);
+  int64_t index;

-  if (index < 0) {
-    return kTvmErrorExecutorModuleNoSuchInput;
+  if (tcodes[0] == kTVMArgInt) {
+    if (args[0].v_int64 > TVMAotExecutor_GetNumInputs(aot_executor.executor)) {
+      return kTvmErrorFunctionCallInvalidArg;
+    }
+
+    index = args[0].v_int64;
+  } else {
+    index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str);
+
+    if (index < 0) {
+      return kTvmErrorExecutorModuleNoSuchInput;
+    }
   }

-  ret_values[0].v_handle = (void*)&aot_executor.executor->args[index].dl_tensor;
+  TVMNDArray* array = &aot_executor.executor->args[index];
+
+  TVMNDArray_IncrementReference(array);
+
+  ret_values[0].v_handle = (void*)(&array->dl_tensor);
   ret_tcodes[0] = kTVMNDArrayHandle;

   return 0;
@@ -103,9 +117,13 @@ int32_t TVMAotExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs, T
 }

   // index past the input entries
-  int64_t idx = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor);
+  int64_t index = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor);
+
+  TVMNDArray* array = &aot_executor.executor->args[index];
+
+  TVMNDArray_IncrementReference(array);

-  ret_values[0].v_handle = (void*)&aot_executor.executor->args[idx].dl_tensor;
+  ret_values[0].v_handle = (void*)(&array->dl_tensor);
   ret_tcodes[0] = kTVMNDArrayHandle;

   return 0;
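With this change, GetInput resolves its argument either as an integer position or as an input name. A sketch of the two call shapes, written as a direct call purely for illustration; at runtime the function is dispatched through the module's registered PackedFunc table, and "data" is a hypothetical input name:

    #include <stdint.h>
    #include <tvm/runtime/c_runtime_api.h>

    /* Declared here only for the sketch; in the source it is file-local. */
    int32_t TVMAotExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs,
                                          TVMValue* ret_values, int* ret_tcodes,
                                          void* resource_handle);

    static void DemoGetInput(void) {
      TVMValue arg, ret;
      int tcode, ret_tcode;

      arg.v_int64 = 0; /* form 1: look up the first input by position */
      tcode = kTVMArgInt;
      TVMAotExecutorModule_GetInput(&arg, &tcode, 1, &ret, &ret_tcode, NULL);

      arg.v_str = "data"; /* form 2: look up by name (hypothetical) */
      tcode = kTVMStr;
      TVMAotExecutorModule_GetInput(&arg, &tcode, 1, &ret, &ret_tcode, NULL);

      /* In both cases ret.v_handle points at the input's DLTensor, and the
       * backing array's reference count was incremented for the caller. */
    }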
47 changes: 22 additions & 25 deletions src/runtime/crt/common/crt_runtime_api.c
@@ -76,9 +76,9 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
 }

 int TVMArrayFree(TVMArrayHandle handle) {
-  TVMNDArray arr;
-  arr.dl_tensor = *handle;
-  return TVMNDArray_Release(&arr);
+  TVMNDArray* arr = (TVMNDArray*)handle;
+
+  return TVMNDArray_Release(arr);
 }

 int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint,
@@ -202,8 +202,8 @@ int TVMModFree(TVMModuleHandle mod) {
   return 0;
 }

-int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
-                        int* ret_type_codes) {
+static int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+                               int* ret_type_codes) {
   const TVMModule* system_lib;

   if (system_lib_handle == kTVMModuleHandleUninitialized) {
@@ -400,8 +400,22 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal
   return 0;
 }

-int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
-                         int* ret_type_code);
+// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom.
+static int RandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+                      int* ret_type_code) {
+  if (num_args != 1) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (type_codes[0] != kTVMDLTensorHandle) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  DLTensor* tensor = (DLTensor*)args[0].v_handle;
+  TVMNDArray arr = {*tensor, 0};
+  return TVMNDArray_RandomFill(&arr);
+}

 tvm_crt_error_t TVMInitializeRuntime() {
   int idx = 0;
   tvm_crt_error_t error = kTvmErrorNoError;
@@ -440,7 +454,7 @@ tvm_crt_error_t TVMInitializeRuntime() {
   }

   if (error == kTvmErrorNoError) {
-    error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &TVMContribRandomFill, 0);
+    error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &RandomFill, 0);
   }

   if (error != kTvmErrorNoError) {
@@ -574,20 +588,3 @@ release_and_return : {
 __attribute__((weak)) tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
   return kTvmErrorFunctionCallNotImplemented;
 }
-
-// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom.
-// Named to correspond with the analogous function in the C++ runtime.
-int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
-                         int* ret_type_code) {
-  if (num_args != 1) {
-    return kTvmErrorFunctionCallNumArguments;
-  }
-
-  if (type_codes[0] != kTVMDLTensorHandle) {
-    return kTvmErrorFunctionCallWrongArgType;
-  }
-
-  DLTensor* tensor = (DLTensor*)args[0].v_handle;
-  TVMNDArray arr = {*tensor};
-  return TVMNDArray_RandomFill(&arr);
-}
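RandomFill bottoms out in TVMPlatformGenerateRandom, and the weak default above only reports kTvmErrorFunctionCallNotImplemented. A sketch of what a platform override might look like, assuming nothing beyond the CRT error-code header; a real port would read a hardware RNG rather than rand():

    #include <stdint.h>
    #include <stdlib.h>
    #include <tvm/runtime/crt/error_codes.h>

    /* A strong definition overrides the weak default at link time. */
    tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
      for (size_t i = 0; i < num_bytes; ++i) {
        buffer[i] = (uint8_t)(rand() & 0xFF); /* illustrative, not secure */
      }
      return kTvmErrorNoError;
    }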
26 changes: 20 additions & 6 deletions src/runtime/crt/common/ndarray.c
@@ -30,8 +30,8 @@

 #include "crt_config.h"

-int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
-                      TVMNDArray* array) {
+static int Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
+                  TVMNDArray* array) {
   memset(array, 0, sizeof(TVMNDArray));
   array->dl_tensor.ndim = ndim;
   tvm_crt_error_t err;
@@ -58,7 +58,7 @@ int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array) {

 int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
                      TVMNDArray* array) {
-  int status = TVMNDArray_Create(ndim, shape, dtype, dev, array);
+  int status = Create(ndim, shape, dtype, dev, array);
   if (status != 0) {
     return status;
   }
@@ -132,7 +132,7 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm) {

 int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype,
                           TVMNDArray* array_view) {
-  int status = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.device, array_view);
+  int status = Create(ndim, shape, dtype, arr->dl_tensor.device, array_view);
   if (status != 0) {
     return status;
   }
@@ -149,21 +149,35 @@ int TVMNDArray_RandomFill(TVMNDArray* arr) {
   return TVMPlatformGenerateRandom(arr->dl_tensor.data, (size_t)num_bytes);
 }

+void TVMNDArray_IncrementReference(TVMNDArray* arr) { arr->reference_count++; }
+
+uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr) {
+  if (arr->reference_count > 0) {
+    arr->reference_count--;
+  }
+
+  return arr->reference_count;
+}
+
 int TVMNDArray_Release(TVMNDArray* arr) {
   tvm_crt_error_t err;
   DLDevice dev = {kDLCPU, 0};

+  if (TVMNDArray_DecrementReference(arr) > 0) {
+    return 0;
+  }
+
   err = TVMPlatformMemoryFree(arr->dl_tensor.data, dev);
   if (err != kTvmErrorNoError) {
     return err;
   }
-  arr->dl_tensor.data = NULL;
+  arr->dl_tensor.data = 0;

   err = TVMPlatformMemoryFree(arr->dl_tensor.shape, dev);
   if (err != kTvmErrorNoError) {
     return err;
   }
-  arr->dl_tensor.shape = NULL;
+  arr->dl_tensor.shape = 0;

   return 0;
 }
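Together, the new functions give TVMNDArray a small shared-ownership lifecycle: Create zero-initializes the count, each owner increments it, and only the Release that drops it to zero frees the backing memory. A sketch (assumes working TVMPlatformMemoryAllocate/Free hooks; shape and dtype are placeholders):

    #include <tvm/runtime/crt/internal/common/ndarray.h>

    static void DemoLifecycle(void) {
      TVMNDArray arr;
      DLDevice dev = {kDLCPU, 0};
      DLDataType f32 = {kDLFloat, 32, 1};
      tvm_index_t shape[1] = {4};

      TVMNDArray_Empty(1, shape, f32, dev, &arr); /* reference_count == 0 */
      TVMNDArray_IncrementReference(&arr);        /* owner 1, e.g. the executor  */
      TVMNDArray_IncrementReference(&arr);        /* owner 2, e.g. an RPC client */

      TVMNDArray_Release(&arr); /* count 2 -> 1: memory is kept  */
      TVMNDArray_Release(&arr); /* count 1 -> 0: memory is freed */
    }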
4 changes: 3 additions & 1 deletion src/runtime/crt/graph_executor/graph_executor.c
@@ -1014,7 +1014,7 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) {
     executor->storage_pool_count++;
   }

-  // Assign the pooled entries. A unified memory pool is used to simplifiy
+  // Assign the pooled entries. A unified memory pool is used to simplify
   // memory assignment for each node entry. The allocated memory on each device
   // is mapped to this pool.
   executor->data_entry_count = executor->node_row_ptr[executor->node_row_ptr_count - 1];
@@ -1031,6 +1031,8 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) {
                                      attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx],
                                      vtype[idx], &executor->data_entry[idx]);
     CHECK_EQ(status, 0, "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id);
+
+    TVMNDArray_IncrementReference(&executor->data_entry[idx]);
   }

   // Release memory
13 changes: 11 additions & 2 deletions src/runtime/crt/graph_executor_module/graph_executor_module.c
@@ -95,7 +95,12 @@ int32_t TVMGraphExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs,

   uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor,
                                              graph_executor.executor->input_nodes[index], 0);
-  ret_values[0].v_handle = (void*)&graph_executor.executor->data_entry[eid].dl_tensor;
+
+  TVMNDArray* array = &graph_executor.executor->data_entry[eid];
+
+  TVMNDArray_IncrementReference(array);
+
+  ret_values[0].v_handle = (void*)(&array->dl_tensor);
   ret_tcodes[0] = kTVMNDArrayHandle;
   return 0;
 }
@@ -158,7 +163,11 @@ int32_t TVMGraphExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs,
   uint32_t index = graph_executor.executor->outputs[output_index].index;
   uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor, nid, index);

-  ret_values[0].v_handle = (void*)&(graph_executor.executor->data_entry[eid].dl_tensor);
+  TVMNDArray* array = &graph_executor.executor->data_entry[eid];
+
+  TVMNDArray_IncrementReference(array);
+
+  ret_values[0].v_handle = (void*)(&array->dl_tensor);
   ret_tcodes[0] = kTVMNDArrayHandle;
   return 0;
 }
3 changes: 0 additions & 3 deletions src/runtime/crt/host/main.cc
@@ -139,9 +139,6 @@ int main(int argc, char** argv) {
            "failed to register GraphExecutor TVMModule");
 #endif

-  CHECK_EQ(TVMAotExecutorModule_Register(), kTvmErrorNoError,
-           "failed to register AoT Executor TVMModule");
-
   int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
                                     (TVMFunctionHandle)&testonly_reset_server, 0);
   if (error) {
@@ -38,7 +38,11 @@ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
 static const uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

 typedef struct TVMNDArray {
+  /*! \brief the actual tensor in DLPack format. NOTE: this must be first element in struct */
   DLTensor dl_tensor;
+
+  /*! \brief count of references to TVMNDArray to avoid early freeing by host */
+  uint32_t reference_count;
 } TVMNDArray;

 int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
@@ -56,6 +60,10 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm);
 int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype,
                           TVMNDArray* array_view);

+void TVMNDArray_IncrementReference(TVMNDArray* arr);
+
+uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr);
+
 int TVMNDArray_Release(TVMNDArray* arr);

 #endif  // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_NDARRAY_H_
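The "must be first element" note is load-bearing: TVMArrayFree in crt_runtime_api.c (above) casts the host-visible DLTensor* straight back to TVMNDArray*, which is only sound while dl_tensor sits at offset zero. A sketch making that invariant explicit, assuming the internal CRT header is on the include path; the _Static_assert and helper are illustrative additions, not part of the PR:

    #include <stddef.h>
    #include <tvm/runtime/crt/internal/common/ndarray.h>

    /* Valid only because dl_tensor is at offset 0, so a DLTensor* handed to
     * the host and the enclosing TVMNDArray* share the same address. */
    _Static_assert(offsetof(TVMNDArray, dl_tensor) == 0,
                   "dl_tensor must be the first member of TVMNDArray");

    static TVMNDArray* NDArrayFromHandle(DLTensor* handle) {
      return (TVMNDArray*)handle;
    }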
6 changes: 6 additions & 0 deletions src/runtime/crt/microtvm_rpc_server/rpc_server.cc
@@ -33,6 +33,7 @@
 #define DMLC_CMAKE_LITTLE_ENDIAN DMLC_IO_USE_LITTLE_ENDIAN
 #define DMLC_LITTLE_ENDIAN 1
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/crt/aot_executor_module.h>
 #include <tvm/runtime/crt/crt.h>
 #include <tvm/runtime/crt/logging.h>
 #include <tvm/runtime/crt/microtvm_rpc_server.h>
@@ -207,6 +208,11 @@ microtvm_rpc_server_t MicroTVMRpcServerInit(microtvm_rpc_channel_write_t write_f
     TVMPlatformAbort(err);
   }

+  err = TVMAotExecutorModule_Register();
+  if (err != kTvmErrorNoError) {
+    TVMPlatformAbort(err);
+  }
+
   DLDevice dev = {kDLCPU, 0};
   void* receive_buffer_memory;
   err = TVMPlatformMemoryAllocate(TVM_CRT_MAX_PACKET_SIZE_BYTES, dev, &receive_buffer_memory);
2 changes: 1 addition & 1 deletion src/runtime/graph_executor/graph_executor.h
@@ -61,7 +61,7 @@ struct TVMOpParam {
 /*!
  * \brief Tiny graph executor.
  *
- * This runtime can be acccesibly in various language via
+ * This runtime can be accessible in various languages via
  * TVM runtime PackedFunc API.
  */
 class TVM_DLL GraphExecutor : public ModuleNode {