diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 6bf1ca6eabd5..d95f985fe63f 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -89,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final { // add function as void return value - CodeGenLLVM::AddFunctionInternal(gvar, f, true); + CodeGenLLVM::AddFunctionInternal(gvar, f); function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); std::ostringstream attr; attr << "1," << DetectROCMmaxThreadsPerBlock(); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ada51677ac16..67c81d2803b6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -227,11 +227,11 @@ void CodeGenLLVM::InitTarget() { } llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) { - return this->DeclareFunctionInternal(gvar, f, false); + return this->DeclareFunctionInternal(gvar, f); } void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { - this->AddFunctionInternal(gvar, f, false); + this->AddFunctionInternal(gvar, f); } void CodeGenLLVM::InitFuncState() { @@ -258,8 +258,7 @@ std::tuple CodeGenLLVM::GetLinkage( return {symbol_name, llvm::Function::PrivateLinkage}; } -llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func, - bool ret_void) { +llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) { if (auto it = functions_.find(gvar.get()); it != functions_.end()) { return it->second; } @@ -275,11 +274,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons alias_var_set_.insert(param.get()); } } - // TODO(tvm-team): - // Update the function type to respect the ret_type field of f. - // Once we allow more flexibility in the PrimFunc. + llvm::FunctionType* ftype = - llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false); + llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false); auto [symbol_name, linkage_type] = GetLinkage(gvar, func); @@ -297,10 +294,10 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons return function; } -void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) { +void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) { this->InitFuncState(); - function_ = DeclareFunctionInternal(gvar, f, ret_void); + function_ = DeclareFunctionInternal(gvar, f); // set var map and align information auto arg_it = function_->arg_begin(); @@ -341,7 +338,10 @@ void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, #endif EmitDebugLocation(f->span); - if (ret_void) { + + if (IsVoidType(f->ret_type)) { + // All other return types are handled when encountering + // builtin::ret(). builder_->CreateRetVoid(); } else { builder_->CreateRet(ConstInt32(0)); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 07d96e40e07c..8c8929c8f093 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -381,9 +381,9 @@ class CodeGenLLVM : public ExprFunctor, std::tuple GetLinkage(const GlobalVar& gvar, const PrimFunc& func); - llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void); + llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f); - void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void); + void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f); // Create extern call llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a40208513079..b6c19c92d70f 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -68,11 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM { public: llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final { // add function as void return value - return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true); + return CodeGenLLVM::DeclareFunctionInternal(gvar, f); } void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final { // add function as void return value - CodeGenLLVM::AddFunctionInternal(gvar, f, true); + CodeGenLLVM::AddFunctionInternal(gvar, f); // annotate as kernel function llvm::LLVMContext* ctx = llvm_target_->GetContext(); module_->getOrInsertNamedMetadata("nvvm.annotations")