Skip to content

Commit

Permalink
[LLVM] Remove the "ret_void" argument of AddFunction (#15127)
Browse files Browse the repository at this point in the history
Prior to this commit, the `"ret_void"` argument needed to be
explicitly provided to `CodeGenLLVM::AddFunction` and
`CodeGenLLVM::DeclareFunction`.  If this was inconsistent with the
`builtin::ret()` usage within the `PrimFunc`, this could cause
the incorrect return type in the generated LLVM-IR, resulting in LLVM
IR verification failures.

This commit removes the `"ret_void"` argument, instead using the type
annotation in `PrimFunc::ret_type`, removing this opportunity for
inconsistency.

This PR is intended to fix a ROCm regression reported in
#14901 (comment).
  • Loading branch information
Lunderberg committed Jun 23, 2023
1 parent bee073b commit 7392432
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {

void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
CodeGenLLVM::AddFunctionInternal(gvar, f);
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
std::ostringstream attr;
attr << "1," << DetectROCMmaxThreadsPerBlock();
Expand Down
22 changes: 11 additions & 11 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ void CodeGenLLVM::InitTarget() {
}

llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
return this->DeclareFunctionInternal(gvar, f, false);
return this->DeclareFunctionInternal(gvar, f);
}

void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
this->AddFunctionInternal(gvar, f, false);
this->AddFunctionInternal(gvar, f);
}

void CodeGenLLVM::InitFuncState() {
Expand All @@ -258,8 +258,7 @@ std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
return {symbol_name, llvm::Function::PrivateLinkage};
}

llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func,
bool ret_void) {
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
return it->second;
}
Expand All @@ -275,11 +274,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
alias_var_set_.insert(param.get());
}
}
// TODO(tvm-team):
// Update the function type to respect the ret_type field of f.
// Once we allow more flexibility in the PrimFunc.

llvm::FunctionType* ftype =
llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);

auto [symbol_name, linkage_type] = GetLinkage(gvar, func);

Expand All @@ -297,10 +294,10 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
return function;
}

void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
this->InitFuncState();

function_ = DeclareFunctionInternal(gvar, f, ret_void);
function_ = DeclareFunctionInternal(gvar, f);

// set var map and align information
auto arg_it = function_->arg_begin();
Expand Down Expand Up @@ -341,7 +338,10 @@ void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f,
#endif

EmitDebugLocation(f->span);
if (ret_void) {

if (IsVoidType(f->ret_type)) {
// All other return types are handled when encountering
// builtin::ret().
builder_->CreateRetVoid();
} else {
builder_->CreateRet(ConstInt32(0));
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::tuple<std::string, llvm::Function::LinkageTypes> GetLinkage(const GlobalVar& gvar,
const PrimFunc& func);

llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);

void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);

// Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM {
public:
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true);
return CodeGenLLVM::DeclareFunctionInternal(gvar, f);
}
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
CodeGenLLVM::AddFunctionInternal(gvar, f);
// annotate as kernel function
llvm::LLVMContext* ctx = llvm_target_->GetContext();
module_->getOrInsertNamedMetadata("nvvm.annotations")
Expand Down

0 comments on commit 7392432

Please sign in to comment.