Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Replace calls to Type::getVectorNumElements #5398

Merged
merged 1 commit into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
if (extent == num_elems && begin == 0) return vec;
CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
Expand All @@ -490,7 +490,7 @@ llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand All @@ -505,7 +505,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(
DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
Expand All @@ -519,16 +519,15 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
int total_lanes = 0;

for (llvm::Value* v : vecs) {
total_lanes += static_cast<int>(
v->getType()->getVectorNumElements());
total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
}
while (vecs.size() > 1) {
std::vector<llvm::Value*> new_vecs;
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
if (lhs_lanes < rhs_lanes) {
lhs = CreateVecPad(lhs, rhs_lanes);
} else if (rhs_lanes < lhs_lanes) {
Expand Down Expand Up @@ -870,16 +869,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, 0, l/2);
} else if (op->is_intrinsic("vectorhigh")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, l/2, l/2);
} else if (op->is_intrinsic("vectorcombine")) {
llvm::Value *v0 = MakeValue(op->args[0]);
llvm::Value *v1 = MakeValue(op->args[1]);
int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand Down
13 changes: 6 additions & 7 deletions src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,20 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr

const std::vector<llvm::Value*>& args) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
if (intrin_lanes == result_ty->getVectorNumElements()) {
size_t num_elems = llvm::cast<llvm::VectorType>(result_ty)->getNumElements();
if (intrin_lanes == num_elems) {
return builder_->CreateCall(f, args);
}

// Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
// compute each result, and then concatenate the vectors (slicing the result if necessary).
CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
CHECK_LT(intrin_lanes, num_elems);
std::vector<llvm::Value*> split_results;
for (size_t i = 0;
i < static_cast<size_t>(result_ty->getVectorNumElements());
i += intrin_lanes) {
for (size_t i = 0; i < num_elems; i += intrin_lanes) {
std::vector<llvm::Value*> split_args;
for (const auto& v : args) {
if (v->getType()->isVectorTy()) {
CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
CHECK_EQ(llvm::cast<llvm::VectorType>(v->getType())->getNumElements(), num_elems);
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
} else {
split_args.push_back(v);
Expand All @@ -147,7 +146,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
split_args));
}
return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
Expand Down