Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,31 @@ def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-na
def export_tvm(
self,
spec: "_spec.ModuleSpecType",
debug: bool = False,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we also add this debug flag to the jit method?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh actually you will need the debug flag in line 405

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, now added.

) -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
"""Export the module to TVM IRModule and parameters"""
"""Export the module to TVM IRModule and parameters

Parameters
----------
spec : _spec.ModuleSpecType
A dictionary mapping each input name to a specification
that defines the inputs shape and dtype.
debug : bool
If set to True, then the exported module will support
effects. This enables things like printing in the graph.

Returns
-------
irmodule : tvm.ir.IRModule
The converted tvm IR representation of the model.
params : Dict[str, tvm.nd.array]
A dictionary of parameters corresponding to the weights of
the model.
"""
from . import spec as _spec # pylint: disable=import-outside-toplevel

spec = _spec.ModuleSpec.from_raw(spec, self)
mod, params = _spec.SpecBuilder().build(spec)
mod, params = _spec.SpecBuilder().build(spec, debug=debug)
return mod, params

def jit( # pylint: disable=too-many-arguments
Expand All @@ -375,6 +394,7 @@ def jit( # pylint: disable=too-many-arguments
device: str = "cpu",
pipeline: str = "zero",
out_format: str = "torch",
debug: bool = False,
) -> Callable:
"""Just-in-time compilation of a nn.model to an executable"""
from tvm import relax # pylint: disable=import-outside-toplevel
Expand All @@ -383,7 +403,7 @@ def jit( # pylint: disable=too-many-arguments

# Convert nn.Module to IRModule
spec = _spec.ModuleSpec.from_raw(spec, self)
mod, params = _spec.SpecBuilder().build(spec)
mod, params = _spec.SpecBuilder().build(spec, debug=debug)

# Convert parameters
device = _str_to_device(device)
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,4 +1290,6 @@ def _convert(arg):


def print_(array: Tensor):
if SpecBuilder.current().io_effect is None:
raise RuntimeError("Printing is only supported when debug mode is on.")
SpecBuilder.current().io_effect.print_(array)
39 changes: 26 additions & 13 deletions python/tvm/relax/frontend/nn/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import defaultdict
import inspect
import threading
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional

from tvm import tir
from tvm.ir import IRModule
Expand Down Expand Up @@ -288,7 +288,9 @@ def __exit__(self, exc_type, exc, traceback) -> None:
assert hasattr(SpecBuilder._tls, "current")
delattr(SpecBuilder._tls, "current")

def build(self, spec: ModuleSpec) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
def build(
self, spec: ModuleSpec, debug: bool = False
) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
"""Build the ModuleSpec to TVM IRModule. Returns the IRModule and the parameters."""

# pylint: disable=protected-access
Expand All @@ -301,7 +303,9 @@ def _params() -> List[Tuple[str, core.Parameter]]:
return params

def _effects() -> List[Tuple[str, core.Effect]]:
result = [("", self.io_effect)]
result = []
if self.io_effect is not None:
result.append(("", self.io_effect))
for name, effect in core._attribute_finder(
spec.module, "", condition_yield=lambda x: isinstance(x, core.Effect)
):
Expand All @@ -321,16 +325,22 @@ def _extern_modules() -> List[Tuple[str, List[str]]]:

# pylint: enable=protected-access

# Disable IO effects if not in debug mode.
if not debug:
self.io_effect = None
params = _params()
effects = _effects()
extern_modules = _extern_modules()
with self:
with self.builder.function("_initialize_effect"):
with self.builder.dataflow():
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
if effects:
with self.builder.function("_initialize_effect"):
with self.builder.dataflow():
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names, spec.method_specs):
with self.builder.function(method_name):
with self.builder.function(
method_name, attrs={"num_input": len(method_spec.arg_specs) + len(effects)}
):
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
Expand Down Expand Up @@ -363,7 +373,7 @@ def _emit_method(
builder: BlockBuilder,
spec: MethodSpec,
params: List[Tuple[str, core.Parameter]],
effects: List[Tuple[str, core.Effect]],
effects: Optional[List[Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
def _unwrap_ret(expr: Any) -> Any:
Expand All @@ -386,16 +396,19 @@ def _convert_input(arg):
inputs = []
for arg in explicit_inputs:
inputs.append(_convert_input(arg))
for name, effect in effects:
inputs.extend(effect.create(name))
for name, param in params:
param._expr = core._tensor_placeholder(name, param.shape, param.dtype)._expr
inputs.append(param._expr)
for name, effect in effects:
inputs.extend(effect.create(name))
# pylint: enable=protected-access
# pylint: enable=protected-access

outputs = spec.method(*explicit_inputs)
effect_outputs = []
for _, effect in effects:
effect_outputs.extend(effect.finalize())
outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
if effect_outputs:
outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
else:
outputs = builder.emit_output(_unwrap_ret(outputs))
return outputs, inputs
10 changes: 3 additions & 7 deletions tests/python/relax/test_frontend_nn_extern_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,21 @@ def forward(self, a: nn.Tensor, b: nn.Tensor):
def forward(
a_1: R.Tensor(("a", "b", "c", "d", 1, 2, 3, 4), dtype="float32"),
b_1: R.Tensor(("c", "d", "e", "f", 5, 6, 7, 8), dtype="float32"),
_io: R.Object,
) -> R.Tuple(
R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"), R.Tuple(R.Object)
):
) -> R.Tensor(("a", "b", "c", "d", "e", "f", 9, 10), dtype="float32"):
a = T.int64()
b = T.int64()
c = T.int64()
d = T.int64()
e = T.int64()
f = T.int64()
R.func_attr({"num_input": 2})
with R.dataflow():
matmul = R.call_dps_packed(
"matmul",
(a_1, b_1),
out_sinfo=R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"),
)
gv1: R.Tuple(
R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32"), R.Tuple(R.Object)
) = matmul, (_io,)
gv1: R.Tensor((a, b, c, d, e, f, 9, 10), dtype="float32") = matmul
R.output(gv1)
return gv1

Expand Down
Loading