Skip to content

Commit

Permalink
[CodeGenC] Handle GlobalVar callee as internal function call (#15103)
Browse files Browse the repository at this point in the history
Analogous to #14901, treat GlobalVar callees as internal function calls in CodeGenC. This specific PR doesn't provide new end-to-end functionality, as the target="c" backend isn't compiled. It does lead into allowing subroutines in any target whose codegen derives from CodeGenC, which will depend on the single-module lowering flow in #14985.

* [CodeGenC] Added unit tests for desired behavior

* [CodeGenC] Handle GlobalVar callee as internal function call

* Update CodeGenC subclasses for updated interface

- Call `DeclareFunction` for each `PrimFunc`, prior to any
  `AddFunction` calls

- Provide both `GlobalVar` and `PrimFunc` to `AddFunction` calls.

* Updated CRT test to expect forward declaration

* Provide forward declarations for call_extern in cmsis

* Avoid duplicate forward declaration

C's automatic pointer cast (e.g. `void*` to `int*`) means that use of
the arguments to infer the function signature may be incorrect.  If a
`call_extern` refers to a function within the same module, only output
a single forward declaration based on the PrimFunc's parameters, not
based on the CallNode's arguments.

* Updated expected ptx cuda

* Cast the AOT pools to the arg type

* Improved tvm::GetType for tvm_access_ptr and address_of

These `Call` instances can return a
`PointerType(PrimType(pointee_dtype))` rather than a
`PrimType(DataType::Handle())`.

* [ARM][Topi] Update micro kernels to use same argument type as caller

Previously, the micro kernels for gemm, avg_pool, max_pool, and
tensordot relied on C's implicit type conversions for the arguments,
when the caller's argument types differ from the signature's parameter
types.  This works, except when the codegen has auto-generated a
forward declaration based on the caller's argument types, such as
during AOT, which then causes a conflicting definition.

Since the codegen cannot determine the functions names from the
`"pragma_import_c"` in order to suppress these forward declarations,
this conflict can be more easily resolved by updating the micro kernel
signatures.  The three types of mismatches are below.

- Use of `int` or `long` parameters, whose width may vary by compiler,
  instead of fixed-width types.

- TIR expecting the data array's integer type to also be used as an
  error code's return type, rather than the micro kernels' `int32_t`
  error code.

- Pointer conversion done during argument conversion.

Type conversions are done at the start of each micro kernel, to avoid
changing types that are used within the computational sections of each
micro kernel.

* Updated unit tests with private=True

Required for internal functions after PR #15214

* Docstring updates from review
  • Loading branch information
Lunderberg committed Aug 8, 2023
1 parent 34cacb0 commit 9ff71f4
Show file tree
Hide file tree
Showing 27 changed files with 591 additions and 297 deletions.
8 changes: 4 additions & 4 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype,
"int32",
f"{func_prefix}_{width}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -68,7 +68,7 @@ def _body():
def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
)
return ib.get()

Expand Down Expand Up @@ -113,8 +113,8 @@ def sum_impl(N, uniq_id):
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
long arr_offset,
int reset) {{
int32_t arr_offset,
int32_t reset) {{
int n;
int32_t *p32;
int32_t res = reset ? 0 : *res16;
Expand Down
87 changes: 70 additions & 17 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
int K,
int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -200,7 +205,12 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -221,7 +231,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -265,9 +279,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
int K,
int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
Expand Down Expand Up @@ -309,7 +328,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -327,7 +350,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
Expand Down Expand Up @@ -368,9 +395,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
int K,
int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -387,7 +419,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -408,7 +444,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -450,9 +490,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
int K,
int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int K = K_arg;
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
Expand All @@ -469,7 +514,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
Expand All @@ -487,7 +536,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
int A_stride = A_stride_arg;
int B_stride = B_stride_arg;
int C_stride = C_stride_arg;
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
Expand Down Expand Up @@ -520,7 +573,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _body():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype,
"int32",
f"{func_prefix}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
Expand All @@ -59,7 +59,7 @@ def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
"int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
)
)
return ib.get()
Expand Down Expand Up @@ -96,7 +96,7 @@ def max_impl(uniq_id):
#endif
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
int N) {{
int32_t N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
return 0;
}}
Expand All @@ -107,7 +107,9 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
int32_t N_arg) {{
int N = N_arg;
for ( int i = 0; i < N; ++ i )
if ( arg[i] > res[i] )
res[i] = arg[i];
Expand All @@ -120,7 +122,8 @@ def max_impl(uniq_id):
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
int32_t N_arg) {{
int N = N_arg;
int32_t *parg32, *pres32;
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
int32_t retcode = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,13 @@ def insert_lines(lines):
#define {function_name.upper()}_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t {function_name}(
int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
int32_t *bias, int32_t *scale
) {{
int32_t *output = output_arg;
int32_t *tensor = tensor_arg;
int32_t *kernel = kernel_arg;
{_init_biased_accumulators(num_outputs)}
{insert_lines(load_tensor_lines)}
Expand Down
28 changes: 12 additions & 16 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
}

/*!
* \brief Emit code that offloads a subgraph to the Cortex-M
*
* \return string of code that offloads a subgraph to the Cortex-M
*/
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }

private:
/*! * \brief Enable storing the last error */
bool debug_last_error;
Expand Down Expand Up @@ -519,11 +512,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
funcs.push_back(kv);

std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
}

std::sort(funcs.begin(), funcs.end(),
Expand All @@ -538,13 +531,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
return name_hint_a < name_hint_b;
});

for (auto kv : funcs) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);
for (auto [gvar, prim_func] : funcs) {
codegen.AddFunction(gvar, prim_func);
}
std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : funcs) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down
26 changes: 20 additions & 6 deletions src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;
Array<String> function_names;

std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}

for (auto [gvar, prim_func] : functions) {
codegen.DeclareFunction(gvar, prim_func);
}
for (auto [gvar, prim_func] : functions) {
codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}

std::string code = codegen.Finish();

Array<String> function_names;
for (auto [gvar, prim_func] : functions) {
function_names.push_back(codegen.GetFunctionName(gvar));
}

return codegen::CSourceModuleCreate(code, "c", function_names);
}

Expand Down
Loading

0 comments on commit 9ff71f4

Please sign in to comment.