Skip to content

Commit

Permalink
[TVMScript][TIR] Pretty print TIR LLVM function name
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Oct 23, 2023
1 parent 6a8cb32 commit 691822c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 9 deletions.
13 changes: 6 additions & 7 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
import warnings
from typing import Any, Optional

import tvm._ffi
Expand Down Expand Up @@ -251,7 +250,7 @@ def call_llvm_intrin(dtype, name, *args, span=None):
The name of the llvm intrinsic function.
args : list
Poistional arguments.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Expand All @@ -271,11 +270,11 @@ def call_llvm_intrin(dtype, name, *args, span=None):
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
raise ValueError(f"Unknown llvm intrinsic function {name}")
return call_intrin(
dtype,
Op.get("tir.call_llvm_intrin"),
tvm.tir.const(llvm_id, "uint32"),
codegen.llvm_get_intrinsic_name(llvm_id),
*args,
span=span,
)
Expand All @@ -293,7 +292,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
The name of the llvm intrinsic function.
args : list
Poistional arguments.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Expand All @@ -313,11 +312,11 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
raise ValueError(f"Unknown llvm intrinsic function {name}")
return call_intrin(
dtype,
Op.get("tir.call_llvm_pure_intrin"),
tvm.tir.const(llvm_id, "uint32"),
codegen.llvm_get_intrinsic_name(llvm_id),
*args,
span=span,
)
Expand Down
8 changes: 7 additions & 1 deletion src/target/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ class CodeGenARM final : public CodeGenCPU {

llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
llvm::Intrinsic::ID id = 0;
if (op->args[0]->IsInstance<StringImmNode>()) {
id = llvm::Function::lookupIntrinsicID(Downcast<StringImm>(op->args[0])->value.c_str());
} else if (op->args[0]->IsInstance<IntImmNode>()) {
id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
}
ICHECK(id != 0) << "Invalid LLVM intrinsic";
if (id == llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
Expand Down
7 changes: 6 additions & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,12 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
ICHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
llvm::Intrinsic::ID id = 0;
if (op->args[0]->IsInstance<StringImmNode>()) {
id = llvm::Function::lookupIntrinsicID(Downcast<StringImm>(op->args[0])->value.c_str());
} else if (op->args[0]->IsInstance<IntImmNode>()) {
id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
}
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,12 @@ def test_comm_reducer(num_args):
assert tvm.tir.max(*range(num_args)) == num_args - 1


def test_llvm_intrin():
with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function dummy_invalid"):
a = tvm.tir.call_llvm_intrin("int32x4", "dummy_invalid", 0)
with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function dummy_invalid"):
a = tvm.tir.call_llvm_pure_intrin("int32x4", "dummy_invalid", 0)


if __name__ == "__main__":
tvm.testing.main()
7 changes: 7 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ def test_cast():
)


def test_llvm_intrin_imm():
a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))')
a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))')


def test_binary_arith():
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
Expand Down

0 comments on commit 691822c

Please sign in to comment.