Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/relax/transform/few_shot_tuning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target&
ICHECK(runner.defined()) << "ValueError: The local runner is not defined!";
}
// create an IRModule
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar("main"), prim_func}}));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>(
{{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, String("main"))}}));
// fetch the number of physical cores
static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
ICHECK(f_cpu_count) << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\"";
Expand Down
14 changes: 8 additions & 6 deletions tests/python/relax/test_transform_few_shot_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def matmul(
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i, j, k in T.grid(32, 32, 32):
with T.block("C"):
Expand All @@ -54,7 +54,7 @@ def matmul(
class Softmax:
@T.prim_func
def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32"), T_softmax_norm: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32")):
T.func_attr({"global_symbol": "main", "op_pattern": 4, "tir.noalias": True})
T.func_attr({"op_pattern": 4, "tir.noalias": True})
# with T.block("root"):
T_softmax_maxelem = T.alloc_buffer((T.int64(8), T.int64(3456)), "float32")
T_softmax_exp = T.alloc_buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32")
Expand Down Expand Up @@ -93,7 +93,7 @@ def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)),
class Fused_Variance_Cast1:
@T.prim_func
def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
T.func_attr({"tir.noalias": True})
# with T.block("root"):
rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
Expand Down Expand Up @@ -152,7 +152,7 @@ def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), co
class Fuse_Mean_Cast1:
@T.prim_func
def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
T.func_attr({"tir.noalias": True})
# with T.block("root"):
rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
Expand Down Expand Up @@ -181,7 +181,7 @@ def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), com
class Module:
@T.prim_func
def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)), "float16"), T_multiply: T.Buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16")):
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
T.func_attr({"tir.noalias": True})
# with T.block("root"):
T_strided_slice_with_axes = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16")
T_divide = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16")
Expand Down Expand Up @@ -339,7 +339,9 @@ def _get_input_output_info(func: tvm.tir.PrimFunc) -> Tuple[List[np.ndarray], Tu
def _expected_results(
mod: tvm.ir.IRModule, inputs: List[np.ndarray], output_shape: Tuple, output_dtype: str
) -> np.ndarray:
rt_mod = tvm.build(mod, target="llvm")
func = _get_single_prim_func(mod)
func = func.with_attr("global_symbol", "main")
rt_mod = tvm.build(func, target="llvm")
data = [
tvm.nd.array(x)
for x in [
Expand Down