diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 35cad1285c1b..b0b72a8d625e 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -53,7 +53,15 @@ def finalize(self) -> List[rx.Var]: def print_(self, tensor: Tensor) -> None: """Encloses the side effect of NDArray printing""" - raise NotImplementedError + self.effect = rx.BlockBuilder.current().emit( + rx.call_pure_packed( + rx.extern("effect.print"), + self.effect, + tensor._expr, # pylint: disable=protected-access + sinfo_args=[rx.ObjectStructInfo()], + ), + name_hint=self.effect.name_hint, + ) @register_func("effect.print") diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 8dea54c72f1c..b6e6578ce67f 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -25,6 +25,7 @@ from ...block_builder import BlockBuilder from ...struct_info import TensorStructInfo, TupleStructInfo from .core import Tensor +from .spec import SpecBuilder IntExpr = Union[int, _tir.PrimExpr] @@ -800,3 +801,7 @@ def _convert(arg): ), name=name_hint, ) + + +def print_(array: Tensor): + SpecBuilder.current().io_effect.print_(array) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 27d7e6d2ff04..048a67110115 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. import pytest +import torch +import sys +import io import tvm import tvm.testing @@ -304,5 +307,40 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten tvm.ir.assert_structural_equal(irmodule, Expected) +def test_print(): + class Model(Module): + def test(self, x: Tensor): + z = op.add(x, x) + op.print_(z) + return x + + # fmt: off + @I.ir_module + class Expected: + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)): + with R.dataflow(): + add: R.Tensor((10, 10), dtype="float32") = R.add(x, x) + _io1: R.Object = R.call_pure_packed("effect.print", _io, add, sinfo_args=(R.Object(),)) + gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io1,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], "float32")}}) + + tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"]) + + if __name__ == "__main__": tvm.testing.main()