Skip to content

Commit

Permalink
[TIR] Preserve existing kTarget function attribute in BindTarget (#14942
Browse files Browse the repository at this point in the history
)

* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI

PRs #14913 and
#14914 made analogous changes to
`MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
Both PRs introduced the same symbol,
`tvm::tir::SubroutineCallRewriter`, a local utility to update internal
calls to a modified function.  While each PR passed CI individually,
and was therefore able to merge, having both changes caused a
duplicate symbol.

This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
their local utilities into anonymous namespaces, avoiding the
conflict.

* [Target] Added WithoutHost method

* [TIR] Preserve existing kTarget function attribute in BindTarget

Previously, if a function already has a `tvm::attr::kTarget`
attribute, it will be overwritten by the `tir.BindTarget` transform.
This commit updates the behavior such that `tir.BindTarget` adds
annotations to functions that are missing a target annotation, but
preserves any existing target annotations.

This is part of a series of commits to simplify the handling of
multi-target builds.
  • Loading branch information
Lunderberg committed May 26, 2023
1 parent 86ba26d commit 81056cc
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 5 deletions.
3 changes: 3 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ class Target : public ObjectRef {
*/
static Target WithHost(const Target& target, const Target& host);

/*! \return The target with the host stripped out */
Target WithoutHost() const;

/*!
* \brief Returns true if \p this target represents an external codegen. If so,
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
Expand Down
10 changes: 10 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,16 @@ Map<String, ObjectRef> TargetNode::Export() const {

Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }

Target Target::WithoutHost() const {
if ((*this)->GetHost()) {
auto output = make_object<TargetNode>(*get());
output->host = NullOpt;
return Target(output);
} else {
return *this;
}
}

int TargetNode::GetTargetDeviceType() const {
if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
return Downcast<Integer>(device_type)->value;
Expand Down
30 changes: 25 additions & 5 deletions src/tir/transforms/primfunc_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,32 @@ namespace tvm {
namespace tir {
namespace transform {
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
tvm::attr::kTarget, target->host.value_or(Target("llvm")));
Target without_host = target.WithoutHost();
Target target_host = Downcast<Target>(target->host.value_or(Target("llvm")));

auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m,
transform::PassContext ctx) {
bool is_externally_exposed = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();

if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
auto func_target_host = func_target.value()->GetHost();
auto target_host = target->GetHost();

if (target_host && !func_target_host && is_externally_exposed) {
auto new_target = Target::WithHost(func_target.value(), target_host.value());
func = WithAttr(std::move(func), tvm::attr::kTarget, new_target);
}
} else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
} else if (is_externally_exposed) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target);
} else {
func = WithAttr(std::move(func), tvm::attr::kTarget, without_host);
}
return WithAttr(std::move(f), tvm::attr::kTarget, target);

func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc);

return func;
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
}
Expand Down
112 changes: 112 additions & 0 deletions tests/python/unittest/test_tir_transform_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,118 @@ def test_bind_target():
assert after["func2"].attrs["target"] == target


class TestBindTarget(tvm.testing.CompareBeforeAfter):
"""BindTarget adds the "target" attribute"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before():
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)


class TestBindTargetWithHostToExposedFunction(tvm.testing.CompareBeforeAfter):
"""BindTarget adds the host target to externally-exposed functions"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))

def before():
T.func_attr({"global_symbol": "main"})
T.evaluate(0)

def expected():
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
T.evaluate(0)


class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter):
"""Internal functions have a target annotation, but without the host
The host portion of the target annotation provides host
parameters, and is used to expose a function externally as part of
`MakePackedAPI` and `MakeUnpackedAPI`. For internal functions, no
external exposure is required, so the host attribute should not be
used.
"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))

def before():
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)


class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter):
"""BindTarget should not replace existing annotations"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before():
T.func_attr({"target": T.target("nvptx")})
T.evaluate(0)

expected = before


class TestBindTargetUpdatesHost(tvm.testing.CompareBeforeAfter):
"""BindTarget should update host for existing annotations"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm -opt-level=0"))

def before():
T.func_attr({"global_symbol": "func", "target": T.target("nvptx")})
T.evaluate(0)

def expected():
T.func_attr(
{
"global_symbol": "func",
"target": T.target("nvptx", host="llvm -opt-level=0"),
}
)
T.evaluate(0)


class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter):
"""BindTarget may apply to multiple functions in a module"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before(self):
@tvm.script.ir_module
class mod:
@T.prim_func
def func1():
T.evaluate(0)

@T.prim_func
def func2():
T.evaluate(0)

return mod

def expected(self):
@tvm.script.ir_module
class mod:
@T.prim_func
def func1():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)

@T.prim_func
def func2():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)

return mod


def test_filter_primfunc():
mod = MockModule
assert mod
Expand Down

0 comments on commit 81056cc

Please sign in to comment.