diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9270b356ba22..2de831e8ad0c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -108,12 +108,12 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g HostDeviceSplitter splitter(device_mod, name_prefix); - auto body = splitter(func->body); - - if (!body.same_as(func->body)) { + if (auto body = splitter(func->body); !body.same_as(func->body)) { func.CopyOnWrite()->body = body; - auto target_host = target->GetHost().value_or(Target("llvm")); - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + } + + if (auto target_host = target->GetHost()) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value()); } return func; diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc index e764d21505d4..a01921239a9f 100644 --- a/tests/cpp/c_codegen_test.cc +++ b/tests/cpp/c_codegen_test.cc @@ -121,5 +121,11 @@ TEST(CCodegen, FunctionOrder) { auto module = build(inputs, Target()); Array func_array = module->GetFunction("get_func_names", false)(); std::vector functions{func_array.begin(), func_array.end()}; - EXPECT_THAT(functions, ElementsAre(StrEq("op_1"), _, StrEq("op_2"), _)); + // The entry point is handled separately from the other functions. + functions.erase(std::remove_if(functions.begin(), functions.end(), + [](const std::string& name) { + return name == tvm::runtime::symbol::tvm_module_main; + }), + functions.end()); + EXPECT_TRUE(std::is_sorted(functions.begin(), functions.end())); } diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index cf866ae005c8..1599b9a031a0 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -168,5 +168,21 @@ def main_kernel(n: T.int32): return mod +class TestSplitHostDevice(BaseCompare): + """Like TestSplitHostDevice, but no device regions to extract + + Even if there are no device regions, the host-side function should + still have its "target" attribute updated. + """ + + def before(): + T.func_attr({"target": T.target("ext_dev", host="llvm")}) + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(0) + + if __name__ == "__main__": tvm.testing.main()