diff --git a/src/runtime/module.cc b/src/runtime/module.cc index ac2b60f8a383..4cec5e3643c1 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -68,6 +68,9 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) if (query_imports) { for (Module& m : self->imports_) { pf = m.operator->()->GetFunction(name, query_imports); + if (pf != nullptr) { + return pf; + } } } return pf; diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 56ebb29c7c65..64f87fb3c561 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -538,6 +538,35 @@ def test_debug_graph_runtime(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +def test_multiple_imported_modules(): + def make_func(symbol): + n = tvm.te.size_var("n") + Ab = tvm.tir.decl_buffer((n,), dtype="float32") + i = tvm.te.var("i") + stmt = tvm.tir.For( + i, + 0, + n - 1, + 0, + 0, + tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1), + ) + return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol) + + def make_module(mod): + mod = tvm.IRModule(mod) + mod = tvm.driver.build(mod, target="llvm") + return mod + + module_main = make_module({"main": make_func("main")}) + module_a = make_module({"func_a": make_func("func_a")}) + module_b = make_module({"func_b": make_func("func_b")}) + module_main.import_module(module_a) + module_main.import_module(module_b) + module_main.get_function("func_a", query_imports=True) + module_main.get_function("func_b", query_imports=True) + + if __name__ == "__main__": test_legacy_compatibility() test_cpu() @@ -545,3 +574,4 @@ def test_debug_graph_runtime(): test_mod_export() test_remove_package_params() test_debug_graph_runtime() + test_multiple_imported_modules()