Skip to content

Commit

Permalink
[Bugfix][Module] Fix recursive GetFunction in runtime::Module (#6859)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Nov 6, 2020
1 parent 41c776e commit 01e76c2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/runtime/module.cc
Expand Up @@ -68,6 +68,9 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports)
if (query_imports) {
for (Module& m : self->imports_) {
pf = m.operator->()->GetFunction(name, query_imports);
if (pf != nullptr) {
return pf;
}
}
}
return pf;
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_runtime_module_based_interface.py
Expand Up @@ -538,10 +538,40 @@ def test_debug_graph_runtime():
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)


def test_multiple_imported_modules():
def make_func(symbol):
n = tvm.te.size_var("n")
Ab = tvm.tir.decl_buffer((n,), dtype="float32")
i = tvm.te.var("i")
stmt = tvm.tir.For(
i,
0,
n - 1,
0,
0,
tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1),
)
return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol)

def make_module(mod):
mod = tvm.IRModule(mod)
mod = tvm.driver.build(mod, target="llvm")
return mod

module_main = make_module({"main": make_func("main")})
module_a = make_module({"func_a": make_func("func_a")})
module_b = make_module({"func_b": make_func("func_b")})
module_main.import_module(module_a)
module_main.import_module(module_b)
module_main.get_function("func_a", query_imports=True)
module_main.get_function("func_b", query_imports=True)


if __name__ == "__main__":
test_legacy_compatibility()
test_cpu()
test_gpu()
test_mod_export()
test_remove_package_params()
test_debug_graph_runtime()
test_multiple_imported_modules()

0 comments on commit 01e76c2

Please sign in to comment.