Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Frontend][NN] Op print_ #15604

Merged
merged 2 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,15 @@ def finalize(self) -> List[rx.Var]:

def print_(self, tensor: Tensor) -> None:
"""Encloses the side effect of NDArray printing"""
raise NotImplementedError
self.effect = rx.BlockBuilder.current().emit(
rx.call_pure_packed(
rx.extern("effect.print"),
self.effect,
tensor._expr, # pylint: disable=protected-access
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=self.effect.name_hint,
)


@register_func("effect.print")
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ...block_builder import BlockBuilder
from ...struct_info import TensorStructInfo, TupleStructInfo
from .core import Tensor
from .spec import SpecBuilder

IntExpr = Union[int, _tir.PrimExpr]

Expand Down Expand Up @@ -800,3 +801,7 @@ def _convert(arg):
),
name=name_hint,
)


def print_(array: Tensor):
SpecBuilder.current().io_effect.print_(array)
38 changes: 38 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import pytest
import torch
import sys
import io

import tvm
import tvm.testing
Expand Down Expand Up @@ -304,5 +307,40 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_print():
class Model(Module):
def test(self, x: Tensor):
z = op.add(x, x)
op.print_(z)
return x

# fmt: off
@I.ir_module
class Expected:
@R.function
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
lv: R.Tuple(R.Object) = (_io,)
gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv

@R.function
def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, x)
_io1: R.Object = R.call_pure_packed("effect.print", _io, add, sinfo_args=(R.Object(),))
gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io1,)
R.output(gv1)
return gv1
# fmt: on

m = Model()
irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], "float32")}})

tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"])


if __name__ == "__main__":
tvm.testing.main()