Skip to content

Commit

Permalink
Revert "[CodeGenC] Handle GlobalVar callee as internal function call" (
Browse files Browse the repository at this point in the history
…#15725)

Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)"

This reverts commit 9ff71f4, a recent change that breaks the Metal backend.
  • Loading branch information
junrushao committed Sep 12, 2023
1 parent e3055c1 commit e88d0d4
Show file tree
Hide file tree
Showing 27 changed files with 299 additions and 591 deletions.
8 changes: 4 additions & 4 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
"int32",
cc.dtype,
f"{func_prefix}_{width}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -68,7 +68,7 @@ def _body():
def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
)
return ib.get()

Expand Down Expand Up @@ -113,8 +113,8 @@ def sum_impl(N, uniq_id):
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
int32_t arr_offset,
int32_t reset) {{
long arr_offset,
int reset) {{
int n;
int32_t *p32;
int32_t res = reset ? 0 : *res16;
Expand Down
87 changes: 17 additions & 70 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
int32_t K_arg,
int K,
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -205,12 +200,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -231,11 +221,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -279,14 +265,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
int32_t K_arg,
int K,
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -328,11 +309,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -350,11 +327,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -395,14 +368,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
int32_t K_arg,
int K,
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -419,11 +387,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -444,11 +408,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -490,14 +450,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
int32_t K_arg,
int K,
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -514,11 +469,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -536,11 +487,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -573,7 +520,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
Expand Down
13 changes: 5 additions & 8 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
"int32",
cc.dtype,
f"{func_prefix}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -59,7 +59,7 @@ def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
"int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
)
)
return ib.get()
Expand Down Expand Up @@ -96,7 +96,7 @@ def max_impl(uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
int32_t N) {{
int N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
return 0;
}}
Expand All @@ -107,9 +107,7 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
int32_t N_arg) {{
int N = N_arg;
int N) {{
for ( int i = 0; i < N; ++ i )
if ( arg[i] > res[i] )
res[i] = arg[i];
Expand All @@ -122,8 +120,7 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
int32_t N_arg) {{
int N = N_arg;
int N) {{
int32_t *parg32, *pres32;
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
int32_t retcode = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,8 @@ def insert_lines(lines):
#define {function_name.upper()}_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t {function_name}(
int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
int32_t *bias, int32_t *scale
int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {{
int32_t *output = output_arg;
int32_t *tensor = tensor_arg;
int32_t *kernel = kernel_arg;
{_init_biased_accumulators(num_outputs)}
{insert_lines(load_tensor_lines)}
Expand Down
28 changes: 16 additions & 12 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
}

/*!
* \brief Emit code that offloads a subgraph to the Cortex-M
*
* \return string of code that offloads a subgraph to the Cortex-M
*/
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }

private:
/*! * \brief Enable storing the last error */
bool debug_last_error;
Expand Down Expand Up @@ -568,11 +575,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);

std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
funcs.push_back(kv);
}

std::sort(funcs.begin(), funcs.end(),
Expand All @@ -587,16 +594,13 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
return name_hint_a < name_hint_b;
});

for (auto [gvar, prim_func] : funcs) {
codegen.AddFunction(gvar, prim_func);
for (auto kv : funcs) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);
}
std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : funcs) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down
26 changes: 6 additions & 20 deletions src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;

Array<String> function_names;
std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}

for (auto [gvar, prim_func] : functions) {
codegen.DeclareFunction(gvar, prim_func);
}
for (auto [gvar, prim_func] : functions) {
codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);
}

std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : functions) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down

0 comments on commit e88d0d4

Please sign in to comment.