Skip to content

Commit

Permalink
[Bugfix][Relax] Remove call to tvm.build for empty TIR module (#16561)
Browse files Browse the repository at this point in the history
Prior to this commit, if a lowered `IRModule` does not contain any TIR
functions, `tvm.relax.build` provided an empty `tir_mod`, which caused
a segfault during TIR compilation.  This could occur when
`tvm.relax.build` is called without an explicit target argument, for a
module that does not define any virtual devices.

This commit updates the `_filter_tir` utility function to return
`None` if there are no TIR functions, rather than an empty
`IRModule`.  In addition, checks for an empty `IRModule` are added to
`tvm.build` and `TIRToRuntime`, so that a similar failure mode would
raise an exception rather than producing a segfault.
  • Loading branch information
Lunderberg committed Feb 14, 2024
1 parent c5aaa99 commit 7336deb
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
3 changes: 3 additions & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def build(
elif isinstance(inputs, PrimFunc):
input_mod = lower(inputs, name=name)
elif isinstance(inputs, tvm.IRModule):
assert (
len(inputs.get_global_vars()) > 0
), "Expected a non-empty IRModule, but the IRModule contained no functions."
input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
Expand Down
16 changes: 8 additions & 8 deletions python/tvm/relax/vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _vmlink(
if ext_libs is None:
ext_libs = []
lib = None
if tir_mod is not None:
if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
lib = tvm.build(
tir_mod,
target=target,
Expand Down Expand Up @@ -348,10 +348,10 @@ def _extract_attrs(mod: tvm.IRModule):
)


def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule:
tir_mod = IRModule({})
tir_mod = tir_mod.with_attrs(mod.attrs)
for gv in mod.get_global_vars():
if isinstance(mod[gv], PrimFunc):
tir_mod[gv] = mod[gv]
return tir_mod
def _filter_tir(mod: tvm.IRModule) -> Optional[tvm.IRModule]:
tir_mod = {gvar: func for gvar, func in mod.functions.items() if isinstance(func, PrimFunc)}

if tir_mod:
return IRModule(tir_mod, attrs=mod.attrs)
else:
return None
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host)

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
CHECK(inputs_arg.size()) << "TIRToRuntime expects at least one IRModule as input.";
std::vector<runtime::Module> device_modules;
Map<Target, IRModule> inputs = inputs_arg;
Target target_host = target_host_arg;
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relax/test_vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7)


def test_vm_compile_without_target_arg(exec_mode):
"""Like test_vm_compile_simple, but with a default target"""

@tvm.script.ir_module
class mod:
@R.function
def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
z = R.call_pure_packed(
"test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
)
return y

ex = relax.build(mod, exec_mode=exec_mode)
inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
vm = relax.VirtualMachine(ex, tvm.cpu())
vm["foo"](inp1, inp2)
tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7)


def test_match_check(exec_mode):
@tvm.script.ir_module
class TestMatchCheck:
Expand Down

0 comments on commit 7336deb

Please sign in to comment.