diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py index 3eb32d8fdb16..e8e45152aae7 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py @@ -55,7 +55,7 @@ def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - "int32", + cc.dtype, f"{func_prefix}_{width}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -68,7 +68,7 @@ def _body(): def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( - tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) + tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) ) return ib.get() @@ -113,8 +113,8 @@ def sum_impl(N, uniq_id): __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}( int16_t *arr, int16_t *res16, - int32_t arr_offset, - int32_t reset) {{ + long arr_offset, + int reset) {{ int n; int32_t *p32; int32_t res = reset ? 0 : *res16; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py index e26e818fbd7e..929dcc6557ff 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py @@ -156,14 +156,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}( - int32_t K_arg, + int K, int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int K = K_arg; - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -205,12 +200,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - - + int A_stride, int B_stride, int C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -231,11 +221,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -279,14 +265,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}( - int32_t K_arg, + int K, int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int K = K_arg; - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -328,11 +309,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -350,11 +327,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -395,14 +368,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}( - int32_t K_arg, + int K, int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int K = K_arg; - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -419,11 +387,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -444,11 +408,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -490,14 +450,9 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}( - int32_t K_arg, + int K, int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int K = K_arg; - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -514,11 +469,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -536,11 +487,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ - int A_stride = A_stride_arg; - int B_stride = B_stride_arg; - int C_stride = C_stride_arg; - + int A_stride, int B_stride, int C_stride) {{ int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -573,7 +520,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #ifdef __cplusplus extern "C" #endif -__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{ +__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ cc[i*C_stride + j] = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py index cfed417c9fe7..66d712a4a0a2 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py @@ -46,7 +46,7 @@ def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - "int32", + cc.dtype, f"{func_prefix}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -59,7 +59,7 @@ def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] + cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] ) ) return ib.get() @@ -96,7 +96,7 @@ def max_impl(uniq_id): #endif __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}( int8_t *res, - int32_t N) {{ + int N) {{ memset(res, (int8_t)-128, N * sizeof(*res)); return 0; }} @@ -107,9 +107,7 @@ def max_impl(uniq_id): __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}( int8_t *arg, int8_t *res, - int32_t N_arg) {{ - int N = N_arg; - + int N) {{ for ( int i = 0; i < N; ++ i ) if ( arg[i] > res[i] ) res[i] = arg[i]; @@ -122,8 +120,7 @@ def max_impl(uniq_id): __attribute__((always_inline)) static inline int32_t max8_{uniq_id}( int8_t *arg, int8_t *res, - int32_t N_arg) {{ - int N = N_arg; + int N) {{ int32_t *parg32, *pres32; int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3; int32_t retcode = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index af3b23e01dcb..d2a8f1ef6905 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -390,13 +390,8 @@ def insert_lines(lines): #define {function_name.upper()}_EXISTS #include __attribute__((always_inline)) static inline int32_t {function_name}( - int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, - int32_t *bias, int32_t *scale + int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale ) {{ - int32_t *output = output_arg; - int32_t *tensor = tensor_arg; - int32_t *kernel = kernel_arg; - {_init_biased_accumulators(num_outputs)} {insert_lines(load_tensor_lines)} diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 6febfe3486af..186fa30f201d 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -46,6 +46,13 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices); } + /*! + * \brief Emit code that offloads a subgraph to the Cortex-M + * + * \return string of code that offloads a subgraph to the Cortex-M + */ + void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } + private: /*! * \brief Enable storing the last error */ bool debug_last_error; @@ -568,11 +575,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_fwd_func_decl = false; bool debug_last_error = GetCompilerAttrs()->debug_last_error; CodeGenCMSISNN codegen; + Array function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error); - - std::vector> funcs; - for (auto [gvar, base_func] : mod->functions) { - funcs.push_back({gvar, Downcast(base_func)}); + std::vector> funcs; + for (auto kv : mod->functions) { + funcs.push_back(kv); } std::sort(funcs.begin(), funcs.end(), @@ -587,16 +594,13 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { return name_hint_a < name_hint_b; }); - for (auto [gvar, prim_func] : funcs) { - codegen.AddFunction(gvar, prim_func); + for (auto kv : funcs) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); } std::string code = codegen.Finish(); - - Array function_names; - for (auto [gvar, prim_func] : funcs) { - function_names.push_back(codegen.GetFunctionName(gvar)); - } - return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc index 6f09e0a0c3f0..0db8d06c3143 100644 --- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -49,30 +49,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_asserts = false; bool emit_fwd_func_decl = false; CodeGenExampleTargetHook codegen; - + Array function_names; std::unordered_set devices; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); - - Map functions; - for (auto [gvar, base_func] : mod->functions) { - auto prim_func = Downcast(base_func); - functions.Set(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - codegen.DeclareFunction(gvar, prim_func); - } - for (auto [gvar, prim_func] : functions) { - codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); } - std::string code = codegen.Finish(); - - Array function_names; - for (auto [gvar, prim_func] : functions) { - function_names.push_back(codegen.GetFunctionName(gvar)); - } - return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc index 487e247f5d38..3b58fda54b52 100644 --- a/src/relay/backend/contrib/uma/tir_to_runtime.cc +++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc @@ -49,6 +49,13 @@ class UMACodegen : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices); } + /*! + * \brief Emit code that offloads a subgraph to the UMA target + * + * \return string of code that offloads a subgraph to the UMA target + */ + void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } + private: String target_str_; }; @@ -56,30 +63,17 @@ class UMACodegen : public codegen::CodeGenCHost { runtime::Module TIRToRuntime(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; - bool emit_fwd_func_decl = true; + bool emit_fwd_func_decl = false; UMACodegen codegen(target->kind->name); + Array function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl); - - Map functions; - for (auto [gvar, base_func] : mod->functions) { - auto prim_func = Downcast(base_func); - functions.Set(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - codegen.DeclareFunction(gvar, prim_func); - } - for (auto [gvar, prim_func] : functions) { - codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); } - std::string code = codegen.Finish(); - - Array function_names; - for (auto [gvar, prim_func] : functions) { - function_names.push_back(codegen.GetFunctionName(gvar)); - } - return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index e0f53e350992..1c0b5094efab 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -131,21 +131,13 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { CodeGenCUDA cg; cg.Init(output_ssa); - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - functions.Set(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - cg.DeclareFunction(gvar, prim_func); - } - for (auto [gvar, prim_func] : functions) { - cg.AddFunction(gvar, prim_func); + cg.AddFunction(f); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index dc3ba0875161..700d85b4ccd4 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -40,22 +40,13 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { CodeGenOpenCL cg; cg.Init(output_ssa); - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - functions.Set(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - cg.DeclareFunction(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - cg.AddFunction(gvar, prim_func); + cg.AddFunction(f); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 187bdc74fe29..a7cc320562cb 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -42,7 +42,6 @@ void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); handle_data_type_.clear(); CodeGenSourceBase::ClearFuncState(); - ReserveKeywordsAsUnique(); } void CodeGenC::ReserveKeywordsAsUnique() { @@ -76,92 +75,51 @@ void CodeGenC::ReserveKeywordsAsUnique() { name_supply_->ReserveName("return"); } -void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) { - PrintFuncPrefix(os); - PrintType(func->ret_type, os); - PrintExtraAttrs(func, os); - os << " " << function_name << "("; - for (size_t i = 0; i < func->params.size(); ++i) { - tir::Var v = func->params[i]; - - if (i > 0) { - os << ", "; - } - - if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, os); - } - - PrintType(GetType(v), os); - - bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias); - bool is_handle = v.dtype().is_handle(); - if (no_alias && is_handle) { - PrintRestrict(v, os); - } - - os << " " << AllocVarID(v.get()); - } - os << ")"; +void CodeGenC::AddFunction(const PrimFunc& f) { + // clear previous generated state. + this->InitFuncState(f); + // reserve keywords + ReserveKeywordsAsUnique(); - // Register handle data type - // TODO(tvm-team): consider simply keep type info in the - // type annotation(via a normalizing rewriting). - for (const auto& param : func->params) { - if (auto* ptr = param->type_annotation.as()) { - if (auto* prim = ptr->element_type.as()) { - RegisterHandleType(param.get(), prim->dtype); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + + this->PrintFuncPrefix(stream); + PrintType(f->ret_type, stream); + this->PrintExtraAttrs(f); + this->stream << " " << static_cast(global_symbol.value()) << "("; + + for (size_t i = 0; i < f->params.size(); ++i) { + tir::Var v = f->params[i]; + std::string vid = AllocVarID(v.get()); + if (i != 0) stream << ", "; + if (v.dtype().is_handle()) { + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); } - } - } -} -void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { - if (internal_functions_.count(gvar)) { - return; - } + PrintType(GetType(v), stream); + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + if (auto* ptr = v->type_annotation.as()) { + if (auto* prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } - auto function_name = [&]() -> String { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { - auto name = global_symbol.value(); - ICHECK(!func_name_supply_->ContainsName(name)) - << "Function " << gvar << " must use global symbol " << name - << ", but this name has already been used."; - func_name_supply_->ReserveName(name); - return name; + if (no_alias) { + PrintRestrict(v, stream); + } } else { - func_name_supply_->ReserveName(gvar->name_hint); - return gvar->name_hint; + PrintType(GetType(v), stream); } - }(); - - internal_functions_.insert({gvar, function_name}); - - InitFuncState(func); - PrintFunctionSignature(function_name, func, fwd_decl_stream); - fwd_decl_stream << ";\n"; -} - -String CodeGenC::GetFunctionName(const GlobalVar& gvar) { - auto it = internal_functions_.find(gvar); - ICHECK(it != internal_functions_.end()) - << "Attempted to find name of " << gvar - << ", but no function with this GlobalVar has been declared"; - return it->second; -} - -void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { - // If the function has already been forward-declared, this is a - // no-op. - DeclareFunction(gvar, f); - auto function_name = GetFunctionName(gvar); - - // clear previous generated state. - InitFuncState(f); - - PrintFunctionSignature(function_name, f, stream); - stream << " {\n"; + stream << ' ' << vid; + } + stream << ") {\n"; this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); @@ -172,15 +130,9 @@ void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { void CodeGenC::PrintFuncPrefix(std::ostream& os) {} -void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {} +void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} -std::string CodeGenC::Finish() { - std::ostringstream code; - code << decl_stream.str(); - code << fwd_decl_stream.str(); - code << stream.str(); - return code.str(); -} +std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -590,17 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); - - // If the call_extern refers to an function within the IRModule, then - // the forward declaration is already provided from DeclareFunction. - if (!func_name_supply_->ContainsName(func->value)) { - Array arg_types; - for (size_t i = 1; i < op->args.size(); i++) { - arg_types.push_back(GetType(op->args[i])); - } - Type ret_type = GetTypeFromRuntimeDataType(op->dtype); - this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); + Array arg_types; + for (size_t i = 1; i < op->args.size(); i++) { + arg_types.push_back(GetType(op->args[i])); } + Type ret_type = GetTypeFromRuntimeDataType(op->dtype); + this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], @@ -668,13 +615,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else { LOG(FATAL) << "Unresolved call " << op->op; } - } else if (auto opt = op->op.as()) { - auto gvar = opt.value(); - auto callee_name = GetFunctionName(gvar); - PrintCallExtern(GetType(GetRef(op)), callee_name, op->args, false, os); } else { - LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " - << "nor a GlobalVar reference to another function in the IRModule"; + ICHECK(op->op.as()); + LOG(FATAL) << "Do not yet support cross function call"; } } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 2921a56ef3a1..93f9ea519c23 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -65,33 +65,12 @@ class CodeGenC : public ExprFunctor, * \param output_ssa Whether output SSA. */ void Init(bool output_ssa); - /*! - * \brief Add the function declaration to the generated module, - * without defining it. - * - * \param gvar The GlobalVar representing the function. - * \param func The function to be compiled. + * \brief Add the function to the generated module. + * \param f The function to be compiled. * \param whether to append return 0 in the end. */ - virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func); - - /*! - * \brief Add the function to the generated module, including its - * declaration and definition. - * - * \param gvar The GlobalVar representing the function. - * \param func The function to be compiled. - */ - virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func); - - /*! - * \brief Get the name of a declared function - * \param gvar The GlobalVar of the function - * \returns The string name of the function - */ - String GetFunctionName(const GlobalVar& gvar); - + void AddFunction(const PrimFunc& f); /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -117,23 +96,7 @@ class CodeGenC : public ExprFunctor, PrintExpr(n, os); return os.str(); } - // The following parts are overloadable print operations. - - /*! \brief Print the function signature before the argument list - * - * The default implementation delegates out to PrintFuncPrefix and - * PrintExtraAttrs. - * - * \param function_name The name of the function - * - * \param func The function whose signature should be printed - * - * \param os The output stream - */ - virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os); - /*! * \brief Print the function header before the argument list * \param os The output stream @@ -146,7 +109,7 @@ class CodeGenC : public ExprFunctor, * * Example: __launch_bounds__(256) for CUDA functions */ - virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); // NOLINT(*) + virtual void PrintExtraAttrs(const PrimFunc& f); /*! * \brief Insert statement before function body. * \param f The function to be compiled. @@ -321,24 +284,10 @@ class CodeGenC : public ExprFunctor, private: /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; - // deep comparison of PrimExpr ExprDeepEqual deep_equal_; - // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; - - /* \brief Map of GlobalVar to their symbol. - * - * For externally-exposed functions, this is given by the - * tvm::attr::kTarget attribute of the PrimFunc. For internal - * functions, this is the name of the function's GlobalVar, possibly - * altered to prevent duplicate names. - */ - std::unordered_map internal_functions_; - - /* \brief Name supply to generate unique function names */ - NameSupply func_name_supply_{""}; }; } // namespace codegen diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index caef43e8af28..3255e11c5d36 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -75,24 +75,19 @@ void CodeGenCHost::InitGlobalContext() { void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } -void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, - bool emit_fwd_func_decl) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (global_symbol) { - function_names_.push_back(global_symbol.value()); - } +void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + function_names_.push_back(global_symbol.value()); emit_fwd_func_decl_ = emit_fwd_func_decl; - CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - ICHECK(global_symbol.defined()) - << "CodeGenCHost: The entry func must have the global_symbol attribute, " - << "but function " << gvar << " only has attributes " << func->attrs; - + CodeGenC::AddFunction(f); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { function_names_.push_back(runtime::symbol::tvm_module_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); - PrintType(func->ret_type, stream); + PrintType(f->ret_type, stream); stream << " " << tvm::runtime::symbol::tvm_module_main << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " << "int* out_ret_tcode, void* resource_handle) {\n"; @@ -133,6 +128,15 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) << "TVM_DLL "; } +std::string CodeGenCHost::Finish() { // NOLINT(*) + std::string ret = decl_stream.str(); + if (emit_fwd_func_decl_) { + ret += fwd_decl_stream.str(); + } + ret += stream.str(); + return ret; +} + void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { @@ -433,38 +437,42 @@ runtime::Module BuildCHost(IRModule mod, Target target) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); cg.SetConstantsByteAlignment(target->GetAttr("constants-byte-alignment").value_or(16)); - - auto is_aot_executor_fn = [](const PrimFunc& func) -> bool { - return func->GetAttr("runner_function", Bool(false)).value(); - }; - - std::vector> funcs; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - funcs.push_back({gvar, prim_func}); + PrimFunc aot_executor_fn; + + std::vector> funcs; + for (auto kv : mod->functions) { + // Make sure that the executor function is the last one to be code generated so that all the + // symbols are available to __tvm_main__ + auto fun_name = std::string(kv.first->name_hint); + bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); + + if (is_aot_executor_fn) { + aot_executor_fn = Downcast(kv.second); + continue; + } + funcs.push_back(kv); } // Sort functions - auto sort_key = [&is_aot_executor_fn](const auto& kv) { - return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; - }; - std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) { - return sort_key(kv_a) < sort_key(kv_b); - }); - - // Declare all functions first. This ensures that all functions, - // including the __tvm_main__ used in AOT, have access to forward - // declarations of other functions in the IRModule. - for (const auto& [gvar, prim_func] : funcs) { - cg.DeclareFunction(gvar, prim_func); + std::sort(funcs.begin(), funcs.end(), + [](std::pair kv_a, + std::pair kv_b) { + std::string name_hint_a = kv_a.first->name_hint; + std::string name_hint_b = kv_b.first->name_hint; + return name_hint_a < name_hint_b; + }); + + // Add all functions except __tvm_main__ + for (auto& kv : funcs) { + ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; + auto f = Downcast(kv.second); + cg.AddFunction(f); } - // Codegen all functions. Passing emit_fwd_func_decl=true adds a - // forward declaration for any `builtin::call_extern`, based on the - // arguments provided to it. - for (const auto& [gvar, prim_func] : funcs) { - cg.AddFunction(gvar, prim_func, emit_fwd_func_decl); + // Add __tvm_main__ + if (aot_executor_fn.defined()) { + emit_fwd_func_decl = true; + cg.AddFunction(aot_executor_fn, emit_fwd_func_decl); } // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build(). @@ -476,10 +484,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { } else { runtime = relay::Runtime::Create("cpp", {}); } - - bool has_aot_executor_fn = std::any_of( - funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); }); - if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) { + if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { cg.InitGlobalContext(); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index aeba685f7422..694104afc0af 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,7 +44,8 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set& devices); void InitGlobalContext(); - void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false); + void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false); + std::string Finish() final; /*! * \brief Add functions from the (unordered) range to the current module in a deterministic * order. This helps with debugging. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7639ce606563..a91f8b016464 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor { PrimExpr threadIdx_z_ext = Integer(1); }; -void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { +void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { ThreadIdxExtractor extractor; extractor(f->body); arith::Analyzer analyzer; @@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { // unable to extract the number of threads per block, hence directly return return; } - os << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; } } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index bc7b34b500d8..3ec0c3bc2d9e 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC { } // override behavior void PrintFuncPrefix(std::ostream& os) final; - void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*) + void PrintExtraAttrs(const PrimFunc& f) final; void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 3db8d216b3b1..b8c30691e21f 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -36,8 +36,6 @@ namespace codegen { void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("v_"); // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -54,33 +52,37 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) { << "};\n\n"; } -void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) { +void CodeGenMetal::AddFunction(const PrimFunc& f) { + // clear previous generated state. + this->InitFuncState(f); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); + // add to alloc buffer type. - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - os << "kernel void " << static_cast(global_symbol.value()) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; size_t limit = target_->GetAttr("max_function_args").value().IntValue(); - if (func->params.size() > limit) { + if (f->params.size() > limit) { LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " "buffers in the kernel"; } - for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { - Var v = func->params[i]; + for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { + Var v = f->params[i]; if (!v.dtype().is_handle()) break; - os << " "; + stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, os); + PrintStorageScope(it->second, stream); } - PrintType(GetType(v), os); + PrintType(GetType(v), stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -89,18 +91,19 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri RegisterHandleType(v.get(), prim->dtype); } } - os << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. - size_t nargs = func->params.size() - num_buffer; + size_t nargs = f->params.size() - num_buffer; std::string varg = name_supply_->FreshName("arg"); if (nargs != 0) { std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; - os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; + stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; - for (size_t i = num_buffer; i < func->params.size(); ++i) { - Var v = func->params[i]; + for (size_t i = num_buffer; i < f->params.size(); ++i) { + Var v = f->params[i]; ICHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; @@ -128,7 +131,7 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams).value(); + auto launch_params = f->GetAttr>(tir::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -147,7 +150,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri } thread_work_dim_ = work_dim; - stream << ")"; + // the function scope. + stream << ") {\n"; + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { @@ -333,33 +342,27 @@ runtime::Module BuildMetal(IRModule mod, Target target) { const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto calling_conv = base_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - - auto prim_func = Downcast(base_func); - functions.Set(gvar, prim_func); - } + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + std::string func_name = global_symbol.value(); - for (auto [gvar, prim_func] : functions) { - source_maker << "// Function: " << gvar->name_hint << "\n"; + source_maker << "// Function: " << func_name << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); - + cg.AddFunction(f); std::string fsource = cg.Finish(); source_maker << fsource << "\n"; if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).operator std::string(); } - smap[cg.GetFunctionName(gvar)] = fsource; + smap[func_name] = fsource; } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 26c991e60df9..36be10d16363 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC { explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) override; + void AddFunction(const PrimFunc& f); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index da6a4de6196a..c15d2253d716 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -595,26 +595,18 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - functions.Set(gvar, prim_func); - } - std::stringstream code; const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc"); - for (auto [gvar, prim_func] : functions) { - code << "// Function: " << gvar->name_hint << std::endl; + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; + code << "// Function: " << kv.first->name_hint << std::endl; CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + cg.AddFunction(f); std::string fsource = cg.Finish(); if (fpostproc) { fsource = (*fpostproc)(fsource, target).operator std::string(); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index aa7a32320c5e..83046de10701 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -145,21 +145,13 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for get_source(). cg.Init(output_ssa); - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - functions.Set(gvar, prim_func); - } - - for (auto [gvar, prim_func] : functions) { - cg.DeclareFunction(gvar, prim_func); - } - for (auto [gvar, prim_func] : functions) { - cg.AddFunction(gvar, prim_func); + cg.AddFunction(f); } std::string whole_code = cg.Finish(); @@ -167,21 +159,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for compilation. Array> kernel_info; - for (auto [gvar, prim_func] : functions) { + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; + auto f = Downcast(kv.second); CodeGenVivadoHLS cg; cg.Init(output_ssa); - - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); + cg.AddFunction(f); std::string code = cg.Finish(); if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { code = (*f)(code, target).operator std::string(); } - auto function_name = cg.GetFunctionName(gvar); - kernel_info.push_back({function_name, code}); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + kernel_info.push_back({global_symbol.value(), code}); } std::string xclbin; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 6a6712a4ce26..4d1d834c7fac 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -45,12 +45,6 @@ std::string CodeGenWebGPU::Finish() { void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); - // skip the first underscore, so SSA variable starts from - name_supply_->FreshName("v_"); - // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); - // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -62,12 +56,28 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} -void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) { +void CodeGenWebGPU::AddFunction(const PrimFunc& f) { + // clear previous generated state. + this->InitFuncState(f); + // skip the first underscore, so SSA variable starts from + name_supply_->FreshName("v_"); + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + + // add to alloc buffer type. + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; + + decl_stream << "//----------------------------------------\n" + << "// function: " << global_symbol.value() << "\n" + << "//----------------------------------------\n"; + std::vector pod_args; int num_buffer = 0; // setup buffer argumemts - for (Var arg : func->params) { + for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { auto* ptr = arg->type_annotation.as(); @@ -101,18 +111,16 @@ void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const Pr } // add to alloc buffer type. // Function header. - os << "fn main(\n" - << " @builtin(workgroup_id) blockIdx : vec3,\n" - << " @builtin(local_invocation_id) threadIdx : vec3\n" - << ")"; -} - -void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { - CodeGenC::AddFunction(gvar, func); - decl_stream << "//----------------------------------------\n" - << "// function: " << GetFunctionName(gvar) << "\n" - << "//----------------------------------------\n"; - + this->stream << "fn main(\n" + << " @builtin(workgroup_id) blockIdx : vec3,\n" + << " @builtin(local_invocation_id) threadIdx : vec3\n" + << ") {\n"; + // the function scope. + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; // anotate workgroup this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; @@ -516,31 +524,22 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; - auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + std::unordered_map smap; + for (auto kv : mod->functions) { + CodeGenWebGPU cg(target); + ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; - functions.Set(gvar, prim_func); - } - - std::unordered_map smap; - for (auto [gvar, prim_func] : functions) { - CodeGenWebGPU cg(target); + std::string f_name = global_symbol.value(); cg.Init(output_ssa); - - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); - + cg.AddFunction(f); std::string code = cg.Finish(); - smap[cg.GetFunctionName(gvar)] = code; + smap[f_name] = code; } auto n = make_object(smap, ExtractFuncInfo(mod)); return runtime::Module(n); diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index 6ae942a3ad49..57f226ba8ad6 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -48,9 +48,7 @@ class CodeGenWebGPU final : public CodeGenC { explicit CodeGenWebGPU(Target target); // overrides std::string Finish() final; - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) final; - void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final; + void AddFunction(const PrimFunc& f); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 90640a6db647..a6f4b5bb3e5b 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -613,14 +613,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } for (const tir::Var& pool_var : metadata_->pools) { - call_args_ss << "((uint8_t*)"; String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name; if (IsInternalWorkspaceBuffer(pool_var)) { - call_args_ss << "&" << pool_name; + call_args_ss << "&" << pool_name << ","; } else { - call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name); + call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ","; } - call_args_ss << "),"; } for (const String& device : metadata_->devices) { call_args_ss << "devices->" << device << ","; diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index fd14f4892154..39214c4546dc 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -70,32 +70,6 @@ Type GetType(const PrimExpr& expr) { return ptr->type_annotation; } } - - if (auto* access = expr.as()) { - if (access->op.same_as(builtin::tvm_access_ptr())) { - ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; - auto type_annotation = Downcast(access->args[0]); - static auto builtin_op = Op::Get("tir.type_annotation"); - ICHECK(type_annotation->op.same_as(builtin_op)) - << "Expected the first argument of builtin tvm_access_ptr() " - << "to be a type annotation, but found " << type_annotation->op; - return PointerType(PrimType(type_annotation->dtype)); - } - } - - if (auto* address_of = expr.as()) { - if (address_of->op.same_as(builtin::address_of())) { - ICHECK_EQ(address_of->args.size(), 1) - << "Builtin address_of() expects a single argument, but received arguments " - << address_of->args; - auto* address = address_of->args[0].as(); - ICHECK(address) - << "Builtin address_of() expects the argument to be a BufferLoad, but received argument " - << address_of->args[0]; - - return PointerType(PrimType(address->dtype)); - } - } // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py index 99e2f0c92300..e001a62ab99a 100644 --- a/tests/python/relay/aot/test_crt_forward_declarations.py +++ b/tests/python/relay/aot/test_crt_forward_declarations.py @@ -33,6 +33,8 @@ AOTTestRunner, ) +pytestmark = pytest.mark.skip(reason="regression introduced in #15725") + def _change_ndarray_layout(arr, src_layout, dst_layout): """Makes a copy of an ndarray, reshaping it to a new data layout. @@ -160,8 +162,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner): lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0] main_source = lib_mod.get_source() - assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2 - assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6 + assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1 + assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3 @tvm.testing.requires_corstone300 diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py index f6145cd1c51a..7bea7577b6bf 100644 --- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py +++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py @@ -135,13 +135,8 @@ def test_write_3x3_depthwise_code(): #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000( - int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, - int32_t *bias, int32_t *scale + int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale ) { - int32_t *output = output_arg; - int32_t *tensor = tensor_arg; - int32_t *kernel = kernel_arg; - int32_t sum_0 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -193,13 +188,8 @@ def test_odd_width_3x3_depthwise_strides_code(): #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4( - int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, - int32_t *bias, int32_t *scale + int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale ) { - int32_t *output = output_arg; - int32_t *tensor = tensor_arg; - int32_t *kernel = kernel_arg; - int32_t sum_0 = *bias, sum_1 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -261,13 +251,8 @@ def test_1x1x8_convolution_code(): #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1( - int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, - int32_t *bias, int32_t *scale + int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale ) { - int32_t *output = output_arg; - int32_t *tensor = tensor_arg; - int32_t *kernel = kernel_arg; - int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -364,13 +349,8 @@ def test_3x3x3_offset_convolution_code(): #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111( - int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, - int32_t *bias, int32_t *scale + int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale ) { - int32_t *output = output_arg; - int32_t *tensor = tensor_arg; - int32_t *kernel = kernel_arg; - int32_t sum_0 = *bias; int32_t tensor__unknown__y00_x00 = tensor[0]; diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 3aca0fc8c77e..d02f8744f129 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -14,15 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import tvm import tvm.testing - from tvm import te -from tvm.contrib import utils -from tvm.script import tir as T, ir as I - import numpy as np +from tvm.contrib import utils def test_add(): @@ -232,39 +228,11 @@ def check_global_packed_func(): check_global_packed_func() -def test_subroutine_call(): - @I.ir_module - class mod: - @T.prim_func - def main(A: T.Buffer(1, dtype="float32")): - mod.subroutine(A.data) - - @T.prim_func(private=True) - def subroutine(A_data: T.handle("float32")): - A = T.decl_buffer(1, dtype="float32", data=A_data) - A[0] = 42.0 - - built = tvm.build(mod, target="c") - - func_names = list(built["get_func_names"]()) - assert ( - "main" in func_names - ), "Externally exposed functions should be listed in available functions." - assert ( - "subroutine" not in func_names - ), "Internal function should not be listed in available functions." - - source = built.get_source() - assert ( - source.count("main(void*") == 2 - ), "Expected two occurrences, for forward-declaration and definition" - assert ( - source.count("subroutine(float*") == 2 - ), "Expected two occurrences, for forward-declaration and definition" - assert ( - source.count("subroutine(") == 3 - ), "Expected three occurrences, for forward-declaration, definition, and call from main." - - if __name__ == "__main__": - tvm.testing.main() + test_add() + test_add_pipeline() + test_reinterpret() + test_ceil() + test_floor() + test_round() + test_call_packed() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 61f0892a9cf3..588a92d87c4b 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -268,7 +268,6 @@ def test_inject_async_copy_barrier(): #define int64_t long long #define uint64_t unsigned long long #endif -extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C); extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64];