From 415617efc9217c98784f37e55f58445518d6cf6b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 13 Jun 2023 10:58:05 -0700 Subject: [PATCH] Fix global symbol issue. --- src/relax/transform/few_shot_tuning.cc | 3 ++- .../python/relax/test_transform_few_shot_tuning.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index c95b099321ce..4ad5e2367524 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -45,7 +45,8 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& ICHECK(runner.defined()) << "ValueError: The local runner is not defined!"; } // create an IRModule - IRModule mod = IRModule(Map({{GlobalVar("main"), prim_func}})); + IRModule mod = IRModule(Map( + {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, String("main"))}})); // fetch the number of physical cores static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count"); ICHECK(f_cpu_count) << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; diff --git a/tests/python/relax/test_transform_few_shot_tuning.py b/tests/python/relax/test_transform_few_shot_tuning.py index be26f8dad53c..0b4e2e08c5ff 100644 --- a/tests/python/relax/test_transform_few_shot_tuning.py +++ b/tests/python/relax/test_transform_few_shot_tuning.py @@ -39,7 +39,7 @@ def matmul( B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # with T.block("root"): for i, j, k in T.grid(32, 32, 32): with T.block("C"): @@ -54,7 +54,7 @@ def matmul( class Softmax: @T.prim_func def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32"), T_softmax_norm: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32")): - T.func_attr({"global_symbol": "main", "op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # with T.block("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(8), T.int64(3456)), "float32") T_softmax_exp = T.alloc_buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32") @@ -93,7 +93,7 @@ def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), class Fused_Variance_Cast1: @T.prim_func def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tir.noalias": True}) # with T.block("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) @@ -152,7 +152,7 @@ def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), co class Fuse_Mean_Cast1: @T.prim_func def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tir.noalias": True}) # with T.block("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) @@ -181,7 +181,7 @@ def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), com class Module: @T.prim_func def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)), "float16"), T_multiply: T.Buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16")): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tir.noalias": True}) # with T.block("root"): T_strided_slice_with_axes = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") T_divide = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") @@ -339,7 +339,9 @@ def _get_input_output_info(func: tvm.tir.PrimFunc) -> Tuple[List[np.ndarray], Tu def _expected_results( mod: tvm.ir.IRModule, inputs: List[np.ndarray], output_shape: Tuple, output_dtype: str ) -> np.ndarray: - rt_mod = tvm.build(mod, target="llvm") + func = _get_single_prim_func(mod) + func = func.with_attr("global_symbol", "main") + rt_mod = tvm.build(func, target="llvm") data = [ tvm.nd.array(x) for x in [