diff --git a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc index 5606301eb5b7..0e13e3a5a714 100644 --- a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc @@ -164,6 +164,16 @@ inline InlineType AutoInlineNode::CheckInline(const s_tir::Schedule& sch, if (ann.value() == "disable") return InlineType::kNoInline; } // Last cond: Check inline into the consumers or the spatial producer + // Skip if: + // - root block (no parent scope). + // - root block is wrapped in a For loop, then parent is non-null, but there is still no enclosing SBlockNode. + const StmtSRefNode* p = block_sref->parent; + for (; p != nullptr; p = p->parent) { + if (p->stmt->IsInstance()) break; + } + if (p == nullptr) { + return InlineType::kNoInline; + } tirx::StmtSRef scope_block = s_tir::GetScopeRoot(sch->state(), block_sref, /*require_stage_pipeline=*/false); if (into_consumer) { diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py index 84231f3469bd..9bc1274cd6c7 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py @@ -533,5 +533,35 @@ def main(T_full: T.Buffer((), "float32")): assert_structural_equal(sch.mod, Full) +def test_no_inline_root_block(): + @tvm.script.ir_module + class MaxReduction: + @T.prim_func + def main( + data: T.Buffer((8, 8), "float32"), + data_red: T.Buffer((), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + with T.block("data_red"): + T.reads(data[0:8, 0:8]) + T.writes(data_red[()]) + with T.init(): + data_red[()] = T.float32(-3.4e38) + for i, j in T.grid(8, 8): + with T.block("update"): + T.reads(data_red[()], data[i, j]) + T.writes(data_red[()]) + data_red[()] = T.max(data_red[()], data[i, j]) + + target = Target("llvm") + (space,) = generate_design_space( + kind="llvm", + mod=MaxReduction, + target=target, + types=ms.schedule_rule.AutoInline, + ) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MaxReduction) + + if __name__ == "__main__": tvm.testing.main()