Skip to content

Commit

Permalink
[TIR] Move SplitHostDevice to before MakePackedAPI
Browse files Browse the repository at this point in the history
This simplifies the logic used in MakePackedAPI, as it is the last user
of the host parameter in a function's target.  After MakePackedAPI,
every PrimFunc has a "target" attribute without a "host".
  • Loading branch information
Lunderberg committed Jun 5, 2023
1 parent 80079b6 commit c31f2b0
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 35 deletions.
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
}

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
Expand All @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
Expand Down
11 changes: 10 additions & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
name_hint + "." + kv.first->name_hint);
}

func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});

Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down
10 changes: 9 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();

// Setup device context
Expand Down Expand Up @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return func;
return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
}

namespace transform {
Expand Down
8 changes: 0 additions & 8 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
};

PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
Target target = opt_target.value();

auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto name_prefix = global_symbol.value_or(gvar->name_hint);

Expand All @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
func.CopyOnWrite()->body = body;
}

if (auto target_host = target->GetHost()) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
}

return func;
}

Expand Down
40 changes: 33 additions & 7 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_makeapi():
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"target": tvm.target.Target("llvm"),
"target": tvm.target.Target("llvm", host="llvm"),
"global_symbol": "main",
}
)
Expand Down Expand Up @@ -90,7 +90,9 @@ def test_variable_passed_from_args():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():

@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
func_attr = {"target": tvm.target.Target("llvm")}
func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"

Expand All @@ -177,6 +181,28 @@ def before():
tvm.ir.assert_structural_equal(before, after)


def test_target_host_removed():
"""After MakePackedAPI, host-side target should be the host
MakePackedAPI is the last transform that requires both the device
and the host. After MakePackedAPI, the target attribute should
only contain the host-side target.
"""

host = tvm.target.Target("llvm")

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
T.evaluate(0)

after = tvm.tir.transform.MakePackedAPI()(before)
target_attr = after["main"].attrs["target"]
assert str(host) == str(target_attr)


def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API
Expand All @@ -190,7 +216,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)
Expand Down
65 changes: 56 additions & 9 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def mod(mod_without_attrs):


def test_noop_if_not_global_symbol(mod_without_attrs):
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
mod_without_attrs
)
target = tvm.target.Target("llvm", host="llvm")
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs)
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)

Expand All @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs):

@tvm.testing.parametrize_targets("c", "llvm", "cuda")
def test_device_setup(mod, target, dev):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod)
target = tvm.target.Target(target, host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert f.params[0].name == "A"
Expand Down Expand Up @@ -138,6 +138,49 @@ def test_body():
assert f.params[2].name == "A"


class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter):
"""After MakeUnpackedAPI, host-side target should be the host
MakeUnpackedAPI is the last transform that requires both the device
and the host. After MakeUnpackedAPI, the target attribute should
only contain the host-side target.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 2)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod


class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification
Expand All @@ -153,7 +196,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -195,12 +238,14 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A_data)

return mod
Expand Down Expand Up @@ -240,7 +285,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
Expand All @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")):

@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A.data)

return mod
Expand Down
14 changes: 7 additions & 7 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_split_host_device_func_attr():
[
tvm.tir.transform.AnnotateDeviceRegions(),
tvm.tir.transform.SplitHostDevice(),
tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.LowerDeviceKernelLaunch(),
]
)(mod)
Expand Down Expand Up @@ -111,7 +112,7 @@ def expected(self):
class mod:
@T.prim_func
def main(n: T.int32):
T.func_attr({"target": T.target("llvm -opt-level=0")})
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
mod.main_kernel(n)

@T.prim_func
Expand Down Expand Up @@ -168,20 +169,19 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDevice(BaseCompare):
class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
"""Like TestSplitHostDevice, but no device regions to extract
Even if there are no device regions, the host-side function should
still have its "target" attribute updated.
Because MakePackedAPI/MakeUnpackedAPI still require both the
device and host, SplitHostDevice does not modify the "target"
attribute.
"""

def before():
T.func_attr({"target": T.target("ext_dev", host="llvm")})
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("llvm")})
T.evaluate(0)
expected = before


if __name__ == "__main__":
Expand Down

0 comments on commit c31f2b0

Please sign in to comment.