Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphRuntime] Support parameter out in the graph runtime debug #4598

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 6 additions & 8 deletions python/tvm/contrib/debugger/debug_runtime.py
Expand Up @@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.

ctx : TVMContext
The context this module is under.
Expand Down Expand Up @@ -188,7 +188,7 @@ def _run_debug(self):
out_tensor = array(out_tensor)
self.debug_datum._output_tensor_list.append(out_tensor)

def debug_get_output(self, node, out):
def debug_get_output(self, node, out=None):
"""Run graph up to node and get the output to out

Parameters
Expand All @@ -199,12 +199,11 @@ def debug_get_output(self, node, out):
out : NDArray
The output array container
"""
ret = None
if isinstance(node, str):
output_tensors = self.debug_datum.get_output_tensors()
try:
ret = output_tensors[node]
except:
out = output_tensors[node]
except KeyError:
node_list = output_tensors.keys()
raise RuntimeError(
"Node "
Expand All @@ -215,10 +214,10 @@ def debug_get_output(self, node, out):
)
elif isinstance(node, int):
output_tensors = self.debug_datum._output_tensor_list
ret = output_tensors[node]
out = output_tensors[node]
else:
raise RuntimeError("Require node index or name only.")
return ret
return out

def run(self, **input_dict):
"""Run forward execution of the graph with debug
Expand All @@ -244,7 +243,6 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0):
ret = self._run_individual(number, repeat, min_repeat_ms)
return ret.strip(",").split(",") if ret else []


def exit(self):
"""Exits the dump folder and all its contents"""
self._remove_dump_root()
10 changes: 6 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Expand Up @@ -22,6 +22,7 @@
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base


def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
Expand Down Expand Up @@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx):

return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))


def get_device_ctx(libmod, ctx):
"""Parse and validate all the device context(s).
Parameters
Expand Down Expand Up @@ -112,12 +114,12 @@ class GraphModule(object):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.

Attributes
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
"""

def __init__(self, module):
Expand All @@ -142,7 +144,7 @@ def set_input(self, key=None, value=None, **params):
The input key

params : dict of str to NDArray
Additonal arguments
Additional arguments
"""
if key is not None:
self._get_input(key).copyfrom(value)
Expand Down Expand Up @@ -211,7 +213,7 @@ def get_output(self, index, out=None):
return self._get_output(index)

def debug_get_output(self, node, out):
"""Run graph upto node and get the output to out
"""Run graph up to node and get the output to out

Parameters
----------
Expand Down