Skip to content

Commit

Permalink
[TIR] Handle subroutine calls in MakePackedAPI (#14913)
Browse files Browse the repository at this point in the history
* [TIR] MakePackedAPI, handle missing kGlobalSymbol

Previously, `MakePackedAPI` required all functions to have the
`kGlobalSymbol` attribute.  This commit updates the behavior such that
`MakePackedAPI` only modifies PrimFuncs that have the `kGlobalSymbol`
attribute, and passes through any other PrimFunc unmodified.

* [TIR] Update calls to externally-exposed subroutines in MakePackedAPI

When a function is updated to use the `PackedFunc` API, any calls made
to that function from elsewhere in the `IRModule` should be updated as
well.

* Bugfix, don't update the callsite unless the callee is also updated
  • Loading branch information
Lunderberg committed May 25, 2023
1 parent bcf7abb commit 41a616f
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 21 deletions.
125 changes: 106 additions & 19 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,91 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
return rewriter(body);
}

class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const Map<GlobalVar, String>& packed_func_methods, Stmt stmt) {
SubroutineCallRewriter rewriter(packed_func_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
}
}

private:
explicit SubroutineCallRewriter(const Map<GlobalVar, String>& packed_func_methods)
: packed_func_methods(packed_func_methods) {}

PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));

if (auto* gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
for (auto arg : node->args) {
cpacked_args.push_back(arg);
}

// push an empty handle to be compatible with current cpacked convention
cpacked_args.push_back(tir::make_zero(DataType::Handle()));
made_change_ = true;
return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args);
}
}

return node;
}
const Map<GlobalVar, String>& packed_func_methods;
bool made_change_{false};
};

/*! \brief Build an assertion that two expressions are equal
 *
 * \param lhs Left-hand side of the comparison
 * \param rhs Right-hand side of the comparison
 * \param msg Message attached to the assertion
 *
 * \returns An AssertStmt on `lhs == rhs` whose body is `Evaluate(0)`
 */
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
  PrimExpr is_equal = lhs == rhs;
  tvm::tir::StringImm message(msg);
  Stmt body = Evaluate(0);
  return AssertStmt(is_equal, message, body);
}

PrimFunc MakePackedAPI(PrimFunc&& func) {
/*! \brief Return the global_symbol of the function, if it should be updated
 *
 * A function is left unchanged when it carries a non-default calling
 * convention, or when it has no global_symbol attribute.
 *
 * \param func The function to be inspected
 *
 * \returns The global_symbol to be used for the function at call
 * sites, or NullOpt if the function is to remain unchanged.
 */
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
  // A function with an explicit calling convention has already been
  // lowered, and should not be modified.
  if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
    if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
      return NullOpt;
    }
  }

  // Internal function calls do not need the PackedFunc API.  A missing
  // global_symbol is therefore a valid input that is reported as
  // NullOpt, and must not be treated as an internal error.
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  if (!global_symbol.defined()) {
    return NullOpt;
  }

  // The target attribute is not needed to answer this query, so it is
  // neither looked up nor validated here.
  return global_symbol;
}

PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
return func;
}
std::string name_hint = global_symbol.value();

Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -292,39 +363,55 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
<< " are used, but are not passed in as API arguments";

func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));

// return the function.
return std::move(func);
return func;
}

namespace transform {

Pass MakePackedAPI() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
auto pass_func = [](IRModule mod, PassContext ctx) {
Map<GlobalVar, String> packed_func_methods;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
packed_func_methods.Set(gvar, global_symbol.value());
}
}
}

IRModuleNode* mptr = mod.CopyOnWrite();
IRModule updates;

for (const auto& kv : mptr->functions) {
if (auto opt = kv.second.as<PrimFunc>()) {
for (const auto& [gvar, base_func] : mptr->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
auto updated_func = MakePackedAPI(std::move(func));
updates.push_back({kv.first, updated_func});
auto orig_func = func;

if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
}

func = MakePackedAPI(std::move(func));

if (!func.same_as(orig_func)) {
updates->Add(gvar, func);
}
}
}

for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return m;
return mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
Expand Down
111 changes: 109 additions & 2 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
from tvm import te
import tvm.testing
from tvm import te, tir
from tvm.script import tir as T, ir as I
from tvm.driver.build_module import schedule_to_module


Expand All @@ -39,7 +43,9 @@ def test_makeapi():
)
)(mod)

f = tvm.tir.transform.MakePackedAPI()(mod)["main"]
before = mod
after = tvm.tir.transform.MakePackedAPI()(mod)
f = after["main"]
assert len(f.params) == 6


Expand All @@ -59,6 +65,19 @@ def _find_next(stmt, type):
return stmt


def _find_compute_scope(func):
    """Return the "compute_scope" AttrStmt found in the body of `func`.

    The body is walked with `post_order_visit`.  When several nodes
    match, the one visited last is returned.  Returns None when no node
    matches.
    """
    matches = []

    def _collect(node):
        is_attr_stmt = isinstance(node, tir.AttrStmt)
        if is_attr_stmt and node.attr_key == "compute_scope":
            matches.append(node)

    tir.stmt_functor.post_order_visit(func.body, _collect)

    return matches[-1] if matches else None


def test_variable_passed_from_args():
ib = tvm.tir.ir_builder.create()

Expand Down Expand Up @@ -143,5 +162,93 @@ def test_device_api_context_implicit_resource_handle():
assert call_extern.args[2] == device_context_in_resource_handle


@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
    """MakePackedAPI should only rewrite functions that have a global_symbol

    With the attribute present, the transformed function is expected to
    have 6 parameters.  Without it, the function is expected to be
    structurally equal to the input.
    """
    func_attr = {
        "target": tvm.target.Target("llvm"),
        **({"global_symbol": "main"} if use_global_symbol else {}),
    }

    @T.prim_func
    def before():
        T.func_attr(func_attr)
        T.evaluate(0)

    mod = tvm.IRModule.from_expr(before)
    after = tvm.tir.transform.MakePackedAPI()(mod)["main"]

    if not use_global_symbol:
        tvm.ir.assert_structural_equal(before, after)
        return

    assert len(after.params) == 6


def test_internal_subroutine_call():
    """Internal subroutines should not use the PackedFunc API

    A subroutine without the "global_symbol" attribute is an internal
    subroutine, and is not directly exposed to a user of the generated
    `runtime.Module`.  Therefore, it doesn't need to follow the
    PackedFunc API.
    """

    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer(1, "float32")):
            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
            before.subroutine(A.data)

        @T.prim_func
        def subroutine(A_data: T.handle("float32")):
            T.func_attr({"target": T.target("llvm")})
            T.evaluate(A_data)

    after = tvm.tir.transform.MakePackedAPI()(before)

    # The internal subroutine must pass through the transform unmodified.
    tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"])

    # `_find_compute_scope` returns None when nothing matches.  Check it
    # explicitly so that a failure is reported as an assertion instead of
    # an AttributeError on `None.body`.
    compute_scope = _find_compute_scope(after["main"])
    assert (
        compute_scope is not None
    ), "Expected the main function to contain a 'compute_scope' AttrStmt"

    # The call inside the compute scope must still target the GlobalVar.
    subroutine_call_op = compute_scope.body.value.op
    assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
        "The main function's CallNode should use the subroutine's GlobalVar as the operation, "
        f"but instead has an operation of type {type(subroutine_call_op).__name__}: "
        f"{subroutine_call_op}"
    )


def test_subroutine_call_to_externally_visible_subroutine():
    """Externally-visible subroutines should use the PackedFunc API

    A subroutine that has a "global_symbol" may be called directly by a
    user, so it must follow the PackedFunc API.  Both its signature and
    the call sites that refer to it are expected to be updated, with
    call sites lowered to `T.tvm_call_cpacked`.
    """

    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer(1, "float32")):
            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
            before.subroutine(A.data)

        @T.prim_func
        def subroutine(A_data: T.handle("float32")):
            T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
            T.evaluate(A_data)

    after = tvm.tir.transform.MakePackedAPI()(before)

    # Both functions are expected to contain a compute scope afterwards.
    compute_scopes = {}
    for func_name in ["main", "subroutine"]:
        scope = _find_compute_scope(after[func_name])
        assert scope is not None
        compute_scopes[func_name] = scope

    subroutine_call_op = compute_scopes["main"].body.value.op
    is_cpacked_call = (
        isinstance(subroutine_call_op, tvm.ir.Op)
        and subroutine_call_op.name == "tir.tvm_call_cpacked"
    )
    assert is_cpacked_call, (
        f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
        f"but instead has an operation of type {subroutine_call_op}"
    )


if __name__ == "__main__":
    # Run every test in this file, including the parametrized ones,
    # instead of only `test_makeapi`.
    tvm.testing.main()

0 comments on commit 41a616f

Please sign in to comment.