Skip to content

Commit

Permalink
[TVMScript][TIR] Pretty print TIR LLVM function name
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Oct 24, 2023
1 parent 6a8cb32 commit af63424
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
9 changes: 4 additions & 5 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
import warnings
from typing import Any, Optional

import tvm._ffi
Expand Down Expand Up @@ -251,7 +250,7 @@ def call_llvm_intrin(dtype, name, *args, span=None):
The name of the llvm intrinsic function.
args : list
Poistional arguments.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Expand All @@ -271,7 +270,7 @@ def call_llvm_intrin(dtype, name, *args, span=None):
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
raise ValueError(f"Unknown llvm intrinsic function {name}")
return call_intrin(
dtype,
Op.get("tir.call_llvm_intrin"),
Expand All @@ -293,7 +292,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
The name of the llvm intrinsic function.
args : list
Poistional arguments.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Expand All @@ -313,7 +312,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
raise ValueError(f"Unknown llvm intrinsic function {name}")
return call_intrin(
dtype,
Op.get("tir.call_llvm_pure_intrin"),
Expand Down
25 changes: 25 additions & 0 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,31 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
dtype_print_location =
static_cast<tir::ScriptDtypePrintLocation>(dtype_locations[op].IntValue());
}
if (name == "call_llvm_pure_intrin" || name == "call_llvm_intrin") {
int n_args = call->args.size();
int64_t id = call->args[0].as<IntImmNode>()->value;
auto f_llvm_lookup_intrinsic_name =
tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name");

Array<ExprDoc> args;
args.reserve(n_args + 1);
if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
}

for (int i = 0; i < n_args; ++i) {
if ((i == 0) && (f_llvm_lookup_intrinsic_name)) {
String name = (*f_llvm_lookup_intrinsic_name)(id);
args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayIndex(i)));
} else {
args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
}
}
if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) {
args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
}
return prefix->Call(args);
}
} else if (call->op.as<GlobalVarNode>()) {
prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
} else {
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,12 @@ def test_comm_reducer(num_args):
assert tvm.tir.max(*range(num_args)) == num_args - 1


def test_llvm_intrin():
with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"):
a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy", 0)
with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"):
a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy", 0)


if __name__ == "__main__":
tvm.testing.main()
7 changes: 7 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ def test_cast():
)


def test_llvm_intrin_imm():
a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))')
a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
_assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))')


def test_binary_arith():
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
Expand Down

0 comments on commit af63424

Please sign in to comment.