Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Move SplitHostDevice to before MakePackedAPI #14986

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/tvm/utils/roofline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,16 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non

@pass_instrument
class SaveLoweredTIR:
"""Save TIR functions from right before final lowering. Right now this
means right before tir.MakePackedAPI."""
"""Save TIR functions for analysis.

def __init__(self, before_pass: str = "tir.MakePackedAPI"):
We need the TIR function in a form that can be handled by
`auto_scheduler.feature.named_features_from_primfunc`, but which
is the closest to the final lowered form as possible. Right now this
means right before tir.SplitHostDevice.

"""

def __init__(self, before_pass: str = "tir.SplitHostDevice"):
"""
Parameters
----------
Expand Down
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
}

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
Expand All @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
Expand Down
11 changes: 10 additions & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
name_hint + "." + kv.first->name_hint);
}

func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});

Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down
10 changes: 9 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();

// Setup device context
Expand Down Expand Up @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return func;
return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
}

namespace transform {
Expand Down
8 changes: 0 additions & 8 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
};

PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
Target target = opt_target.value();

auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto name_prefix = global_symbol.value_or(gvar->name_hint);

Expand All @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
func.CopyOnWrite()->body = body;
}

if (auto target_host = target->GetHost()) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
}

return func;
}

Expand Down
4 changes: 3 additions & 1 deletion tests/python/contrib/test_ethosu/test_encode_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ def get_graph():
# nothing else was overrwritten.
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
tir_to_cs_translator.translate(tir_mod, params)

Expand Down
8 changes: 6 additions & 2 deletions tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ def test_buffer_info_extraction():
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
tir_mod = test_case["tir_module"]
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
for buffer_var, info in buffer_info.items():
Expand Down Expand Up @@ -959,7 +961,9 @@ def check_buffer(address, region, length, buffer_var):

for test_case in test_cases:
tir_mod = test_case["tir_module"]
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
candidate_regions_for_scratch = [5, 2, 1]
(
Expand Down
40 changes: 33 additions & 7 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_makeapi():
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"target": tvm.target.Target("llvm"),
"target": tvm.target.Target("llvm", host="llvm"),
"global_symbol": "main",
}
)
Expand Down Expand Up @@ -90,7 +90,9 @@ def test_variable_passed_from_args():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():

@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
func_attr = {"target": tvm.target.Target("llvm")}
func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"

Expand All @@ -177,6 +181,28 @@ def before():
tvm.ir.assert_structural_equal(before, after)


def test_target_host_removed():
"""After MakePackedAPI, host-side target should be the host

MakePackedAPI is the last transform that requires both the device
and the host. After MakePackedAPI, the target attribute should
only contain the host-side target.
"""

host = tvm.target.Target("llvm")

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
T.evaluate(0)

after = tvm.tir.transform.MakePackedAPI()(before)
target_attr = after["main"].attrs["target"]
assert str(host) == str(target_attr)


def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API

Expand All @@ -190,7 +216,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)
Expand Down
65 changes: 56 additions & 9 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def mod(mod_without_attrs):


def test_noop_if_not_global_symbol(mod_without_attrs):
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
mod_without_attrs
)
target = tvm.target.Target("llvm", host="llvm")
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs)
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)

Expand All @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs):

@tvm.testing.parametrize_targets("c", "llvm", "cuda")
def test_device_setup(mod, target, dev):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod)
target = tvm.target.Target(target, host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert f.params[0].name == "A"
Expand Down Expand Up @@ -138,6 +138,49 @@ def test_body():
assert f.params[2].name == "A"


class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter):
"""After MakeUnpackedAPI, host-side target should be the host

MakeUnpackedAPI is the last transform that requires both the device
and the host. After MakeUnpackedAPI, the target attribute should
only contain the host-side target.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 2)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod


class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification

Expand All @@ -153,7 +196,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -195,12 +238,14 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A_data)

return mod
Expand Down Expand Up @@ -240,7 +285,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
Expand All @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")):

@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A.data)

return mod
Expand Down
14 changes: 7 additions & 7 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_split_host_device_func_attr():
[
tvm.tir.transform.AnnotateDeviceRegions(),
tvm.tir.transform.SplitHostDevice(),
tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.LowerDeviceKernelLaunch(),
]
)(mod)
Expand Down Expand Up @@ -111,7 +112,7 @@ def expected(self):
class mod:
@T.prim_func
def main(n: T.int32):
T.func_attr({"target": T.target("llvm -opt-level=0")})
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
mod.main_kernel(n)

@T.prim_func
Expand Down Expand Up @@ -168,20 +169,19 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDevice(BaseCompare):
class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
"""Like TestSplitHostDevice, but no device regions to extract

Even if there are no device regions, the host-side function should
still have its "target" attribute updated.
Because MakePackedAPI/MakeUnpackedAPI still require both the
device and host, SplitHostDevice does not modify the "target"
attribute.
"""

def before():
T.func_attr({"target": T.target("ext_dev", host="llvm")})
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("llvm")})
T.evaluate(0)
expected = before


if __name__ == "__main__":
Expand Down