Skip to content

Commit

Permalink
[TIR] Handle subroutine calls in MakeUnpackedAPI
Browse files Browse the repository at this point in the history
Prior to this commit, MakeUnpackedAPI required all functions to be
annotated with `kGlobalSymbol` (`"global_symbol"`).  This commit
updates the transformation to apply only to functions that have the
`kGlobalSymbol` attribute, and to update any internal callsites of the
modified functions.

This is analogous to the changes made in
apache#14913, which updates
`MakePackedAPI`.
  • Loading branch information
Lunderberg committed May 22, 2023
1 parent cd45513 commit ed4c9f9
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 25 deletions.
109 changes: 90 additions & 19 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,79 @@
namespace tvm {
namespace tir {

PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const std::unordered_set<const GlobalVarNode*>& external_methods,
Stmt stmt) {
SubroutineCallRewriter rewriter(external_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
}
}

private:
explicit SubroutineCallRewriter(const std::unordered_set<const GlobalVarNode*>& external_methods)
: external_methods_(external_methods) {}

PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));

if (auto gvar = node->op.as<GlobalVarNode>()) {
if (external_methods_.count(gvar)) {
Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
if (auto* as_call = arg.as<CallNode>()) {
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
PrimExpr data_ptr = as_call->args[0];
made_change_ = true;
return data_ptr;
}
}
return arg;
});
if (!args.same_as(node->args)) {
node.CopyOnWrite()->args = args;
}
}
}

return std::move(node);
}
const std::unordered_set<const GlobalVarNode*>& external_methods_;
bool made_change_{false};
};

PrimFunc MakeUnpackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return func;
}
}

// Internal function calls do not need API updates
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute";
if (!global_symbol.defined()) {
return func;
}

auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute";
Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget ("
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();

auto* func_ptr = func.CopyOnWrite();

// Setup device context
int target_device_type = target.value()->GetTargetDeviceType();
Integer device_type(target_device_type);
Integer device_id(0);
PrimExpr node = StringImm("default");
ObjectRef node = String("default");
const Stmt nop = Evaluate(0);
std::vector<Stmt> device_init;

Expand Down Expand Up @@ -82,31 +141,43 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return std::move(func);
return func;
}

namespace transform {

Pass MakeUnpackedAPI() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
auto pass_func = [](IRModule mod, PassContext ctx) {
std::unordered_set<const GlobalVarNode*> external_methods;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto* prim_func = base_func.as<PrimFuncNode>()) {
if (prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
external_methods.insert(gvar.get());
}
}
}

IRModule updates;

for (const auto& kv : mptr->functions) {
if (auto opt = kv.second.as<PrimFunc>()) {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
auto updated_func = MakeUnpackedAPI(std::move(func));
updates.push_back({kv.first, updated_func});

if (auto body = SubroutineCallRewriter::Apply(external_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
}

func = MakeUnpackedAPI(std::move(func));
if (!func.same_as(base_func)) {
updates->Add(gvar, func);
}
}
}

for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return m;
return mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {});
Expand Down
158 changes: 152 additions & 6 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import pytest

import tvm
from tvm import te
from tvm import te, tir
from tvm.script import tir as T, ir as I
import numpy


Expand All @@ -39,17 +40,20 @@ def mod(mod_without_attrs):
return mod


def test_fails_if_not_global_symbol(mod_without_attrs):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
def test_noop_if_not_global_symbol(mod_without_attrs):
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
mod_without_attrs
)
with pytest.raises(tvm.TVMError, match="Expect PrimFunc to have the global_symbol attribute"):
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)


def test_fails_if_no_target(mod_without_attrs):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod_without_attrs)
with pytest.raises(tvm.TVMError, match="Require the target attribute"):
with pytest.raises(
tvm.TVMError,
match="MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget",
):
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]


Expand Down Expand Up @@ -134,5 +138,147 @@ def test_body():
assert f.params[2].name == "A"


class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification
A subroutine without the "global_symbol" attribute is an internal
subroutine, and is not directly exposed to a user of the generated
`runtime.Module`.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("llvm")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("llvm")})
T.evaluate(A_data)

return mod


class TestSubroutineCallToExternallyVisibleSubroutine(tvm.testing.CompareBeforeAfter):
"""Externally-visible subroutines should be updated
Subroutines that are exposed externally should be updated by
MakeUnpackedAPI.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)

return mod


class TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeAfter):
"""Callsites of externally-visible subroutines may require updates
The MakeUnpackedAPI transform lowers all buffers into a data
pointer to a primitive type. If a subroutine call is currently
passing a DLTensor produced by `T.tvm_make_stack_array` into the
subroutine, the callsite should be updated to instead pass the
data pointer directly.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
T.Cast("float32", 0),
0,
dtype="handle",
)
)

@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A.data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
T.evaluate(A_data)

return mod


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ed4c9f9

Please sign in to comment.