Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Handle subroutine calls in MakePackedAPI #14913

Merged
merged 3 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
125 changes: 106 additions & 19 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,91 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
return rewriter(body);
}

class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const Map<GlobalVar, String>& packed_func_methods, Stmt stmt) {
SubroutineCallRewriter rewriter(packed_func_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
}
}

private:
explicit SubroutineCallRewriter(const Map<GlobalVar, String>& packed_func_methods)
: packed_func_methods(packed_func_methods) {}

PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));

if (auto* gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
for (auto arg : node->args) {
cpacked_args.push_back(arg);
}

// push an empty handle to be compatible with current cpacked convention
cpacked_args.push_back(tir::make_zero(DataType::Handle()));
made_change_ = true;
return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args);
}
}

return node;
}
const Map<GlobalVar, String>& packed_func_methods;
bool made_change_{false};
};

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}

PrimFunc MakePackedAPI(PrimFunc&& func) {
/* \brief Return the global_symbol of the function, if it should be updated
*
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or NullOpt if the function is to remain unchanged.
*/
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return NullOpt;
}
}

// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
if (!global_symbol.defined()) {
return NullOpt;
}

auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
int target_device_type = target.value()->GetTargetDeviceType();
return global_symbol;
}

PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
return func;
}
std::string name_hint = global_symbol.value();

Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -292,39 +363,55 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
<< " are used, but are not passed in as API arguments";

func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));

// return the function.
return std::move(func);
return func;
}

namespace transform {

Pass MakePackedAPI() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
auto pass_func = [](IRModule mod, PassContext ctx) {
Map<GlobalVar, String> packed_func_methods;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
packed_func_methods.Set(gvar, global_symbol.value());
}
}
}

IRModuleNode* mptr = mod.CopyOnWrite();
IRModule updates;

for (const auto& kv : mptr->functions) {
if (auto opt = kv.second.as<PrimFunc>()) {
for (const auto& [gvar, base_func] : mptr->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
auto updated_func = MakePackedAPI(std::move(func));
updates.push_back({kv.first, updated_func});
auto orig_func = func;

if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
}

func = MakePackedAPI(std::move(func));

if (!func.same_as(orig_func)) {
updates->Add(gvar, func);
}
}
}

for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return m;
return mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
Expand Down
111 changes: 109 additions & 2 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
from tvm import te
import tvm.testing
from tvm import te, tir
from tvm.script import tir as T, ir as I
from tvm.driver.build_module import schedule_to_module


Expand All @@ -39,7 +43,9 @@ def test_makeapi():
)
)(mod)

f = tvm.tir.transform.MakePackedAPI()(mod)["main"]
before = mod
after = tvm.tir.transform.MakePackedAPI()(mod)
f = after["main"]
assert len(f.params) == 6


Expand All @@ -59,6 +65,19 @@ def _find_next(stmt, type):
return stmt


def _find_compute_scope(func):
result = None

def _visitor(stmt):
if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope":
nonlocal result
result = stmt

tir.stmt_functor.post_order_visit(func.body, _visitor)

return result


def test_variable_passed_from_args():
ib = tvm.tir.ir_builder.create()

Expand Down Expand Up @@ -143,5 +162,93 @@ def test_device_api_context_implicit_resource_handle():
assert call_extern.args[2] == device_context_in_resource_handle


@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
func_attr = {"target": tvm.target.Target("llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"

@T.prim_func
def before():
T.func_attr(func_attr)
T.evaluate(0)

after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"]
if use_global_symbol:
assert len(after.params) == 6
else:
tvm.ir.assert_structural_equal(before, after)


def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API

A subroutine without the "global_symbol" attribute is an internal
subroutine, and is not directly exposed to a user of the generated
`runtime.Module`. Therefore, it doesn't need to follow the
PackedFunc API.
"""

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)
tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"])

compute_scope = _find_compute_scope(after["main"])
subroutine_call_op = compute_scope.body.value.op
assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
f"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
f"but instead has an operation of type {subroutine_call_op}"
)


def test_subroutine_call_to_externally_visible_subroutine():
"""Externally-visible subroutines should use the PackedFunc API

Because the subroutine may be called directly by a user, it must
use the PackedFunc API. Its signature should be updated to the
PackedFunc signature, and call sites should be updated to use
`T.tvm_call_cpacked`.
"""

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)

main_compute_scope = _find_compute_scope(after["main"])
assert main_compute_scope is not None
subroutine_compute_scope = _find_compute_scope(after["subroutine"])
assert subroutine_compute_scope is not None

subroutine_call_op = main_compute_scope.body.value.op
assert (
isinstance(subroutine_call_op, tvm.ir.Op)
and subroutine_call_op.name == "tir.tvm_call_cpacked"
), (
f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
f"but instead has an operation of type {subroutine_call_op}"
)


if __name__ == "__main__":
test_makeapi()