Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail early before running invalid dynamic graphs #5856

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/tvm/relay/backend/vm.py
Expand Up @@ -27,6 +27,7 @@
import tvm.runtime.vm as vm_rt
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm.relay.ty import type_has_any
from tvm.relay.backend.interpreter import Executor
from . import _vm

Expand Down Expand Up @@ -253,6 +254,12 @@ def _make_executor(self, expr=None):

def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
ret_type = self.mod["main"].checked_type.ret_type
if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
self.target):
raise ValueError(
"Virtual Machine only supports dynamic graphs on CPU, got output type",
ret_type, "on target", self.target)
return self.vm.run(*args)

return _vm_wrapper
3 changes: 3 additions & 0 deletions python/tvm/relay/build_module.py
Expand Up @@ -354,6 +354,9 @@ def _make_executor(self, expr=None):
if expr:
self.mod["main"] = expr
ret_type = self.mod["main"].checked_type.ret_type
if _ty.type_has_any(ret_type):
raise ValueError("Graph Runtime only supports static graphs, got output type",
ret_type)
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(self.mod, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
Expand Down