diff --git a/include/tvm/runtime/crt/func_registry.h b/include/tvm/runtime/crt/func_registry.h index 4f8a19af591e..50737f871798 100644 --- a/include/tvm/runtime/crt/func_registry.h +++ b/include/tvm/runtime/crt/func_registry.h @@ -42,7 +42,7 @@ typedef struct TVMFuncRegistry { /*! \brief Names of registered functions, concatenated together and separated by \0. * An additional \0 is present at the end of the concatenated blob to mark the end. * - * Byte 0 is the number of functions in `funcs`. + * Byte 0 and 1 are the number of functions in `funcs`. */ const char* names; @@ -50,6 +50,31 @@ typedef struct TVMFuncRegistry { const TVMBackendPackedCFunc* funcs; } TVMFuncRegistry; +/*! + * \brief Get the of the number of functions from registry. + * + * \param reg TVMFunctionRegistry instance that contains the function. + * \return The number of functions from registry. + */ +uint16_t TVMFuncRegistry_GetNumFuncs(const TVMFuncRegistry* reg); + +/*! + * \brief Set the number of functions to registry. + * + * \param reg TVMFunctionRegistry instance that contains the function. + * \param num_funcs The number of functions + * \return 0 when successful. + */ +int TVMFuncRegistry_SetNumFuncs(const TVMFuncRegistry* reg, const uint16_t num_funcs); + +/*! + * \brief Get the address of 0th function from registry. + * + * \param reg TVMFunctionRegistry instance that contains the function. + * \return the address of 0th function from registry + */ +const char* TVMFuncRegistry_Get0thFunctionName(const TVMFuncRegistry* reg); + /*! * \brief Get packed function from registry by name. * diff --git a/src/runtime/crt/common/func_registry.c b/src/runtime/crt/common/func_registry.c index 116a5c496f1b..49cef8fd70eb 100644 --- a/src/runtime/crt/common/func_registry.c +++ b/src/runtime/crt/common/func_registry.c @@ -60,14 +60,29 @@ int strcmp_cursor(const char** cursor, const char* name) { return return_value; } +uint16_t TVMFuncRegistry_GetNumFuncs(const TVMFuncRegistry* reg) { + uint16_t num_funcs; + memcpy(&num_funcs, reg->names, sizeof(num_funcs)); + return num_funcs; +} + +int TVMFuncRegistry_SetNumFuncs(const TVMFuncRegistry* reg, const uint16_t num_funcs) { + memcpy((char*)reg->names, &num_funcs, sizeof(num_funcs)); + return 0; +} + +const char* TVMFuncRegistry_Get0thFunctionName(const TVMFuncRegistry* reg) { + // NOTE: first function name starts at index 2 to skip num_funcs. + return (reg->names + sizeof(uint16_t)); +} + tvm_crt_error_t TVMFuncRegistry_Lookup(const TVMFuncRegistry* reg, const char* name, tvm_function_index_t* function_index) { tvm_function_index_t idx; - const char* reg_name_ptr; + const char* reg_name_ptr = TVMFuncRegistry_Get0thFunctionName(reg); idx = 0; - // NOTE: reg_name_ptr starts at index 1 to skip num_funcs. - for (reg_name_ptr = reg->names + 1; *reg_name_ptr != '\0'; reg_name_ptr++) { + for (; *reg_name_ptr != '\0'; reg_name_ptr++) { if (!strcmp_cursor(®_name_ptr, name)) { *function_index = idx; return kTvmErrorNoError; @@ -82,9 +97,9 @@ tvm_crt_error_t TVMFuncRegistry_Lookup(const TVMFuncRegistry* reg, const char* n tvm_crt_error_t TVMFuncRegistry_GetByIndex(const TVMFuncRegistry* reg, tvm_function_index_t function_index, TVMBackendPackedCFunc* out_func) { - uint8_t num_funcs; + uint16_t num_funcs; - num_funcs = reg->names[0]; + num_funcs = TVMFuncRegistry_GetNumFuncs(reg); if (function_index >= num_funcs) { return kTvmErrorFunctionIndexInvalid; } @@ -101,7 +116,8 @@ tvm_crt_error_t TVMMutableFuncRegistry_Create(TVMMutableFuncRegistry* reg, uint8 reg->registry.names = (const char*)buffer; buffer[0] = 0; // number of functions present in buffer. - buffer[1] = 0; // end of names list marker. + buffer[1] = 0; // note that we combine the first two elements to form a 16-bit function index. + buffer[2] = 0; // end of names list marker. // compute a guess of the average size of one entry: // - assume average function name is around ~10 bytes @@ -117,13 +133,12 @@ tvm_crt_error_t TVMMutableFuncRegistry_Create(TVMMutableFuncRegistry* reg, uint8 tvm_crt_error_t TVMMutableFuncRegistry_Set(TVMMutableFuncRegistry* reg, const char* name, TVMBackendPackedCFunc func, int override) { size_t idx; - char* reg_name_ptr; + char* reg_name_ptr = (char*)TVMFuncRegistry_Get0thFunctionName(&(reg->registry)); idx = 0; // NOTE: safe to discard const qualifier here, since reg->registry.names was set from // TVMMutableFuncRegistry_Create above. - // NOTE: reg_name_ptr starts at index 1 to skip num_funcs. - for (reg_name_ptr = (char*)reg->registry.names + 1; *reg_name_ptr != 0; reg_name_ptr++) { + for (; *reg_name_ptr != 0; reg_name_ptr++) { if (!strcmp_cursor((const char**)®_name_ptr, name)) { if (override == 0) { return kTvmErrorFunctionAlreadyDefined; @@ -149,7 +164,11 @@ tvm_crt_error_t TVMMutableFuncRegistry_Set(TVMMutableFuncRegistry* reg, const ch reg_name_ptr += name_len + 1; *reg_name_ptr = 0; ((TVMBackendPackedCFunc*)reg->registry.funcs)[idx] = func; - ((char*)reg->registry.names)[0]++; // increment num_funcs. + + uint16_t num_funcs; + // increment num_funcs. + num_funcs = TVMFuncRegistry_GetNumFuncs(&(reg->registry)) + 1; + TVMFuncRegistry_SetNumFuncs(&(reg->registry), num_funcs); return kTvmErrorNoError; } diff --git a/src/target/func_registry_generator.cc b/src/target/func_registry_generator.cc index 7c948d50cbb9..d679bf379b62 100644 --- a/src/target/func_registry_generator.cc +++ b/src/target/func_registry_generator.cc @@ -31,7 +31,13 @@ namespace target { std::string GenerateFuncRegistryNames(const Array& function_names) { std::stringstream ss; - ss << (unsigned char)(function_names.size()); + + unsigned char function_nums[sizeof(uint16_t)]; + *reinterpret_cast(function_nums) = function_names.size(); + for (auto f : function_nums) { + ss << f; + } + for (auto f : function_names) { ss << f << '\0'; } diff --git a/tests/crt/func_registry_test.cc b/tests/crt/func_registry_test.cc index 9f0e7f8d1a5a..5962a3acee39 100644 --- a/tests/crt/func_registry_test.cc +++ b/tests/crt/func_registry_test.cc @@ -82,7 +82,7 @@ TEST(StrCmpScan, Test) { } TEST(FuncRegistry, Empty) { - TVMFuncRegistry registry{"\000", NULL}; + TVMFuncRegistry registry{"\000\000", NULL}; EXPECT_EQ(kTvmErrorFunctionNameNotFound, TVMFuncRegistry_Lookup(®istry, "foo", NULL)); EXPECT_EQ(kTvmErrorFunctionIndexInvalid, @@ -101,7 +101,7 @@ static int Bar(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_ } // Matches the style of registry defined in generated C modules. -const char* kBasicFuncNames = "\002Foo\0Bar\0"; // NOTE: final \0 +const char* kBasicFuncNames = "\002\000Foo\0Bar\0"; // NOTE: final \0 const TVMBackendPackedCFunc funcs[2] = {&Foo, &Bar}; const TVMFuncRegistry kConstRegistry = {kBasicFuncNames, (const TVMBackendPackedCFunc*)funcs}; @@ -111,7 +111,8 @@ TEST(FuncRegistry, ConstGlobalRegistry) { // Foo EXPECT_EQ(kBasicFuncNames[0], 2); - EXPECT_EQ(kBasicFuncNames[1], 'F'); + EXPECT_EQ(kBasicFuncNames[1], 0); + EXPECT_EQ(kBasicFuncNames[2], 'F'); EXPECT_EQ(kTvmErrorNoError, TVMFuncRegistry_Lookup(&kConstRegistry, "Foo", &func_index)); EXPECT_EQ(0, func_index);