Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Fix CodeGenLLVM::LinkParameters #8213

Merged
merged 1 commit into from
Jun 10, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 29 additions & 47 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
md_builder_.reset(new llvm::MDBuilder(*ctx_));
// types
t_void_ = llvm::Type::getVoidTy(*ctx_);
t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace());
t_int_ = llvm::Type::getInt32Ty(*ctx_);
t_char_ = llvm::Type::getInt8Ty(*ctx_);
t_int8_ = llvm::Type::getInt8Ty(*ctx_);
Expand Down Expand Up @@ -191,20 +191,10 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
// It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
// but they are at a different layer in the compiler...
std::vector<llvm::Type*> param_types;
// args
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
// tcodes
param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
// num_args
param_types.push_back(t_int_);
// ret_args
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
// ret_tcodes
param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
// resource_handle
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace());

// args, tcodes, num_args, ret_value, ret_tcode, resource_handle
std::vector<llvm::Type*> param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_};
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);

llvm::Function* function =
Expand All @@ -215,41 +205,29 @@ void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {

llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
builder_->SetInsertPoint(entry);
std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
llvm::ConstantInt::get(t_int32_, 0)};
auto args_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[0],

auto getArg = [function](int i) -> llvm::Argument* {
#if TVM_LLVM_VERSION >= 100
return function->getArg(i);
#elif TVM_LLVM_VERSION >= 50
return &function->arg_begin()[i];
#else
&(*(function->arg_begin())),
return &*std::next(function->arg_begin(), i);
#endif
llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
llvm::Value* sid = builder_->CreateBitCast(
builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
builder_->CreateInBoundsGEP(args_array, zero_index_list)),
t_int64_);
};

llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace());
llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p));

auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p);
auto ret_value =
builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace()));

llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
auto ret_types_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[4],
#else
&(*(std::next(function->arg_begin(), 4))),
#endif
llvm::ArrayType::get(t_int_, 1)->getPointerTo());
auto retval_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[3],
#else
&(*std::next(function->arg_begin(), 3)),
#endif
llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo());
llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);

builder_->SetInsertPoint(default_block);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode);
areusch marked this conversation as resolved.
Show resolved Hide resolved
builder_->CreateRet(ConstInt32(kTvmErrorNoError));

// Add data to the global section.
Expand All @@ -258,16 +236,20 @@ void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first;
llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
*module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);
auto dtype = tvm::runtime::DataType(kv.second->param->dtype);
size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment);
#if TVM_LLVM_VERSION >= 100
param_symbol->setAlignment(llvm::Align(align));
#else
param_symbol->setAlignment(align);
#endif

llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function);
switch_inst->addCase(
llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
builder_->SetInsertPoint(case_block);
builder_->CreateStore(
builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())),
builder_->CreateInBoundsGEP(retval_array, zero_array_index_list));
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle),
builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);
builder_->CreateRet(ConstInt32(0));
}
}
Expand Down