From ea93e01b8b3ef5c524a479eb37064cd8e41ad3e9 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 17 Apr 2026 21:30:55 +0800 Subject: [PATCH 1/4] ImprovePrint ExternFunc struct_info when non-default --- src/script/printer/relax/function.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 24ae192c73b3..1a516204e50c 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -22,6 +22,19 @@ namespace tvm { namespace script { namespace printer { +static bool HasDefaultExternFuncStructInfo(const relax::ExternFunc& n) { + const auto* sinfo = n->struct_info_.as(); + if (sinfo == nullptr || sinfo->params.defined() || sinfo->purity || + !sinfo->ret->IsInstance() || !sinfo->derive_func.defined()) { + return false; + } + static const EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args"); + if (!fn.defined()) { + return false; + } + return sinfo->derive_func.value().same_as(fn); +} + bool AtTopLevelFunction(const IRDocsifier& d) { // fewer than 2 frames: not in a function at all if (d->frames.size() < 2) { @@ -128,8 +141,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { - // TODO(@junrushao): print more information out of extern function. - return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)}); + ffi::Array args; + args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol"))); + if (!HasDefaultExternFuncStructInfo(n)) { + args.push_back(d->AsDoc(n->struct_info_, n_p->Attr("struct_info_"))); + } + return Relax(d, "ExternFunc")->Call(args); }); TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax); From 24ff3e11098d0b4e1c0d1beb540f0c4d41e5a18a Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 17 Apr 2026 22:52:33 +0800 Subject: [PATCH 2/4] simplify HasDefaultExternFuncStructInfo --- src/script/printer/relax/function.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 1a516204e50c..c759fa80aeef 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -25,14 +25,10 @@ namespace printer { static bool HasDefaultExternFuncStructInfo(const relax::ExternFunc& n) { const auto* sinfo = n->struct_info_.as(); if (sinfo == nullptr || sinfo->params.defined() || sinfo->purity || - !sinfo->ret->IsInstance() || !sinfo->derive_func.defined()) { + !sinfo->ret->IsInstance()) { return false; } - static const EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args"); - if (!fn.defined()) { - return false; - } - return sinfo->derive_func.value().same_as(fn); + return true; } bool AtTopLevelFunction(const IRDocsifier& d) { From 293a683a9fd558f6e7019a04e01c1283a2a1f614 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 18 Apr 2026 13:56:32 +0800 Subject: [PATCH 3/4] add test case: test_extern_func_with_struct_info --- .../relax/test_tvmscript_printer_relax.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index c50c1fcb254d..159b411b633e 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -98,6 +98,28 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): ) +def test_extern_func_with_struct_info(): + obj = IRModule( + { + "my_ext": relax.ExternFunc( + "my_ext", + relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True), + ), + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + my_ext: R.ExternFunc("my_ext", R.Callable((), R.Tensor(dtype="float32", ndim=2), True)) +""", + ) + + def test_nested_function(): @I.ir_module class NestedFunction: From 7ca6e627671c2379df842c9d10ac28eb9e2a9c78 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 18 Apr 2026 15:15:43 +0800 Subject: [PATCH 4/4] add test case: test_extern_func_with_struct_info_roundtrip --- .../python/relax/test_tvmscript_printer_relax.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 159b411b633e..cf3e28388eb5 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -115,11 +115,24 @@ def test_extern_func_with_struct_info(): @I.ir_module class Module: - my_ext: R.ExternFunc("my_ext", R.Callable((), R.Tensor(dtype="float32", ndim=2), True)) + my_ext = R.ExternFunc("my_ext", R.Callable((), R.Tensor(dtype="float32", ndim=2), True)) """, ) +def test_extern_func_with_struct_info_roundtrip(): + mod = IRModule( + { + "my_ext": relax.ExternFunc( + "my_ext", + relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True), + ), + } + ) + roundtrip = tvm.script.from_source(mod.script(verbose_expr=True)) + tvm.ir.assert_structural_equal(mod, roundtrip) + + def test_nested_function(): @I.ir_module class NestedFunction: