Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGenC][Redo] Handle GlobalVar callee as internal function call #15835

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype,
"int32",
f"{func_prefix}_{width}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -68,7 +68,7 @@ def _body():
def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
)
return ib.get()

Expand Down Expand Up @@ -113,8 +113,8 @@ def sum_impl(N, uniq_id):
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
long arr_offset,
int reset) {{
int32_t arr_offset,
int32_t reset) {{
int n;
int32_t *p32;
int32_t res = reset ? 0 : *res16;
Expand Down
87 changes: 70 additions & 17 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
int K,
int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -200,7 +205,12 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -221,7 +231,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -265,9 +279,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
int K,
int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -309,7 +328,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -327,7 +350,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -368,9 +395,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
int K,
int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -387,7 +419,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -408,7 +444,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -450,9 +490,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
int K,
int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -469,7 +514,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -487,7 +536,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -520,7 +573,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype,
"int32",
f"{func_prefix}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -59,7 +59,7 @@ def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
"int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
)
)
return ib.get()
Expand Down Expand Up @@ -96,7 +96,7 @@ def max_impl(uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
int N) {{
int32_t N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
return 0;
}}
Expand All @@ -107,7 +107,9 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
int32_t N_arg) {{
int N = N_arg;
for ( int i = 0; i < N; ++ i )
if ( arg[i] > res[i] )
res[i] = arg[i];
Expand All @@ -120,7 +122,8 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
int32_t N_arg) {{
int N = N_arg;
int32_t *parg32, *pres32;
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
int32_t retcode = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,13 @@ def insert_lines(lines):
#define {function_name.upper()}_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t {function_name}(
int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
int32_t *bias, int32_t *scale
) {{
int32_t *output = output_arg;
int32_t *tensor = tensor_arg;
int32_t *kernel = kernel_arg;
{_init_biased_accumulators(num_outputs)}
{insert_lines(load_tensor_lines)}
Expand Down
28 changes: 12 additions & 16 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
}

/*!
* \brief Emit code that offloads a subgraph to the Cortex-M
*
* \return string of code that offloads a subgraph to the Cortex-M
*/
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }

private:
/*! * \brief Enable storing the last error */
bool debug_last_error;
Expand Down Expand Up @@ -575,11 +568,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
funcs.push_back(kv);

std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
}

std::sort(funcs.begin(), funcs.end(),
Expand All @@ -594,13 +587,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
return name_hint_a < name_hint_b;
});

for (auto kv : funcs) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);
for (auto [gvar, prim_func] : funcs) {
codegen.AddFunction(gvar, prim_func);
}
std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : funcs) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down
26 changes: 20 additions & 6 deletions src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;
Array<String> function_names;

std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}

for (auto [gvar, prim_func] : functions) {
codegen.DeclareFunction(gvar, prim_func);
}
for (auto [gvar, prim_func] : functions) {
codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}

std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : functions) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down