diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 827e4e032920..db5776890ab9 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
   * \param loop_body The loop body
   * \return A stmt, the loop nest
   */
-  using FMakeForLoop =
-      ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
-                                        ffi::Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
+  using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
+      ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
+      ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt loop_body)>;
   /*! \brief The loop variable. */
   ffi::Array<tvm::tir::Var> vars;
   /*! \brief The domains of iteration. */
   ffi::Array<Range> doms;
+  /*! \brief The optional steps of iteration. */
+  ffi::Array<ffi::Optional<PrimExpr>> steps;
   /*! \brief The for loop generating function. */
   FMakeForLoop f_make_for_loop;
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 24ce8fdf990a..07c7fe262bb3 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -228,37 +228,45 @@ ffi::Array<tvm::tir::Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Serial(PrimExpr start, PrimExpr stop,
-                ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
+                ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt,
+                ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The parallel For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Parallel(PrimExpr start, PrimExpr stop,
-                  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
+                  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt,
+                  ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The vectorized For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Vectorized(PrimExpr start, PrimExpr stop,
-                    ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
+                    ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt,
+                    ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The unrolled For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Unroll(PrimExpr start, PrimExpr stop,
-                ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
+                ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt,
+                ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The thread-binding For statement.
  * \param start The minimum value of iteration.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 1b8041e36cc1..0831b84cf6fe 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -717,7 +717,7 @@ enum class ForKind : int {
  *
  * \code
  *
- *  for (loop_var = min; loop_var < min + extent; ++loop_var) {
+ *  for (loop_var = min; loop_var < min + extent; loop_var += step) {
  *    // body
  *  }
  * \endcode
@@ -748,6 +748,10 @@ class ForNode : public StmtNode {
   * and can be ignored in most passes.
   */
  ffi::Map<ffi::String, ffi::Any> annotations;
+  /*!
+   * \brief The loop step. It is one if not specified.
+   */
+  ffi::Optional<PrimExpr> step;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
@@ -758,8 +762,13 @@ class ForNode : public StmtNode {
        .def_ro("kind", &ForNode::kind)
        .def_ro("body", &ForNode::body)
        .def_ro("thread_binding", &ForNode::thread_binding)
-        .def_ro("annotations", &ForNode::annotations);
+        .def_ro("annotations", &ForNode::annotations)
+        .def_ro("step", &ForNode::step);
  }
+
+  /*! \brief Check whether the loop step is trivial, i.e. absent or equal to one. */
+  bool HasTrivialStep() const;
+
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
};

@@ -771,8 +780,8 @@ class For : public Stmt {
 public:
  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
              ffi::Optional<IterVar> thread_binding = std::nullopt,
-              ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
-              Span span = Span());
+              ffi::Map<ffi::String, ffi::Any> annotations = {},
+              ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());

  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 6d746d73b1be..31e48260f5c7 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -677,7 +677,11 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L


 def serial(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The serial For statement.

@@ -692,6 +696,9 @@ def serial(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.

+    step : Optional[PrimExpr]
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -703,11 +710,15 @@ def serial(
         start = IntImm(start.dtype, 0)
     else:
         start = 0
-    return _ffi_api.Serial(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Serial(start, stop, annotations, step)  # type: ignore[attr-defined] # pylint: disable=no-member


 def parallel(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The parallel For statement.

@@ -722,6 +733,9 @@ def parallel(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.

+    step : Optional[PrimExpr]
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -733,11 +747,15 @@ def parallel(
         start = IntImm(start.dtype, 0)
     else:
         start = 0
-    return _ffi_api.Parallel(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Parallel(start, stop, annotations, step)  # type: ignore[attr-defined] # pylint: disable=no-member


 def vectorized(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The vectorized For statement.

@@ -752,6 +770,9 @@ def vectorized(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.

+    step : Optional[PrimExpr]
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -763,11 +784,15 @@ def vectorized(
         start = IntImm(start.dtype, 0)
     else:
         start = 0
-    return _ffi_api.Vectorized(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Vectorized(start, stop, annotations, step)  # type: ignore[attr-defined] # pylint: disable=no-member


 def unroll(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The unrolled For statement.

@@ -782,6 +807,9 @@ def unroll(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.

+    step : Optional[PrimExpr]
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -793,7 +821,7 @@ def unroll(
         start = IntImm(start.dtype, 0)
     else:
         start = 0
-    return _ffi_api.Unroll(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Unroll(start, stop, annotations, step)  # type: ignore[attr-defined] # pylint: disable=no-member


 def thread_binding(
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 85ab1982f384..f8cbc0b4f5bc 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -18,7 +18,7 @@
 import contextlib
 from functools import partial
-from typing import Any
+from typing import Any, Dict, Optional

 import tvm
 from tvm.ir import GlobalVar, PrimType
@@ -168,6 +168,28 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
     return default


+def range_sugar(
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    step: Optional[PrimExpr] = None,
+    *,
+    annotations: Dict[str, Any] = None,
+) -> T.frame.ForFrame:
+    """Sugar for the Python ``range`` builtin."""
+
+    # Since `tir.For` does not support reversed iteration semantics,
+    # the step must be a positive integer literal when the range sugar is used.
+    if step is not None:
+        try:
+            step = int(step)
+            if step <= 0:
+                raise ValueError(f"Only support positive step in range(), got {step}")
+        except TypeError:
+            raise ValueError(f"Only support literal step in range(), got {step}") from None
+
+    return T.serial(start, stop, annotations=annotations, step=step)
+
+
 @dispatch.register(token="tir", type_name="For")
 def visit_for(self: Parser, node: doc.For) -> None:
     """The for visiting method for tir.
@@ -379,7 +401,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
     privacy = find_decorator_annotation(node, "private", default=False)
     self.function_annotations = None
     with self.var_table.with_frame():
-        self.var_table.add("range", T.serial)
+        self.var_table.add("range", range_sugar)
         with T.prim_func(is_private=privacy):
             T.func_name(node.name)
             if node.returns is not None:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b09224d..3b4a78e53e2f 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -202,7 +202,7 @@ def scope_attr(self, node, attr_key, value):
             value = op.max(1, value)
         self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

-    def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
+    def for_range(self, begin, end, name="i", dtype=None, kind="serial", step=None):
         """Create a for iteration scope.
         Parameters
@@ -223,6 +223,10 @@ def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
         kind : str, optional
             The special tag on the for loop.

+        step : PrimExpr, optional
+            The loop step. Defaults to None, which
+            represents a step of one.
+
         Returns
         -------
         loop_scope : With.Scope of Var
@@ -275,7 +279,7 @@ def _exit_cb():
             kind_id = _stmt.ForKind.UNROLLED
         else:
             raise ValueError("Unknown kind")
-        self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
+        self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), step=step))
         return WithScope(loop_var, _exit_cb)

diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index 22cec3033497..96ed9dfdbc96 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -31,6 +31,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
     pass_ctx = tvm.transform.PassContext.current()
     config = pass_ctx.config
     passes = [
+        tir.transform.CanonicalizeLoop(),
         tir.transform.LowerCrossThreadReduction(),
         tir.transform.LowerInitBlock(),
         tir.transform.PlanAndUpdateBufferAllocationLocation(),
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index bd90d5257495..448ace3ade63 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -145,6 +145,10 @@ class For(Stmt):
         The thread this loop binds to. Only valid
         if kind is ThreadBinding

+    step : Optional[PrimExpr]
+        The loop step. Defaults to None, which
+        represents a step of one.
+
     annotations: Optional[Mapping[str, Object]]
         Additional annotation hints.

@@ -159,6 +163,7 @@ class For(Stmt):
     body: Stmt
     thread_binding: Optional[IterVar]
     annotations: Mapping[str, Object]
+    step: Optional[PrimExpr]
     span: Optional[Span]

     def __init__(
@@ -170,6 +175,7 @@ def __init__(
         body: Stmt,
         thread_binding: Optional[IterVar] = None,
         annotations: Optional[Mapping[str, Object]] = None,
+        step: Optional[PrimExpr] = None,
         span: Optional[Span] = None,
     ) -> None:
         self.__init_handle_by_constructor__(
@@ -181,6 +187,7 @@ def __init__(
             body,
             thread_binding,
             annotations,
+            step,
             span,
         )

diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 39105f21a23c..88cf4720d3a6 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1171,3 +1171,14 @@ def LowerVtcmAlloc():
         The result pass
     """
     return _ffi_api.LowerVtcmAlloc()  # type: ignore
+
+
+def CanonicalizeLoop():
+    """Canonicalize loops to start from zero and use a trivial (unit) step.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.CanonicalizeLoop()  # type: ignore
diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
index f83edb3e90c6..837f2f0a5dcb 100644
--- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc
+++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
     if (shard > 1) {
       arith::Analyzer analyzer;
       ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
-      return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
-                 new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
+      new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
+      return new_loop;
     }
   }
   return new_loop;
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 94eef40f59be..7c10b6cdc8d1 100644
--- a/src/script/ir_builder/tir/frame.cc
+++
b/src/script/ir_builder/tir/frame.cc @@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() { void ForFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); + AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts))); } void AssertFrameNode::ExitWithScope() { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index b981b90bd81b..00f9c28475b4 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -362,19 +362,23 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ ForFrame Method(PrimExpr start, PrimExpr stop, \ - ffi::Optional> annotations) { \ + ffi::Optional> annotations, \ + ffi::Optional step) { \ PrimExpr min = start; \ PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ ObjectPtr n = ffi::make_object(); \ int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ n->doms = {Range::FromMinExtent(min, extent)}; \ + n->steps = {step}; \ n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms, \ + ffi::Array> steps, \ tvm::tir::Stmt body) { \ ICHECK_EQ(vars.size(), 1); \ ICHECK_EQ(doms.size(), 1); \ + ICHECK_EQ(steps.size(), 1); \ return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ - annotations.value_or(ffi::Map())); \ + annotations.value_or(ffi::Map()), steps[0]); \ }; \ return ForFrame(n); \ } @@ -396,13 +400,16 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, DataType dtype = DataType(min.dtype().code(), bits, 1); n->vars = {Var("v", dtype)}; n->doms = {Range::FromMinExtent(min, extent)}; + n->steps = {std::nullopt}; n->f_make_for_loop = [annotations, thread, dtype](ffi::Array vars, ffi::Array doms, + ffi::Array> steps, Stmt body) -> For { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); + ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0]))); IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, - annotations.value_or(ffi::Map())); + annotations.value_or(ffi::Map()), std::nullopt); }; return ForFrame(n); } @@ -412,19 +419,22 @@ ForFrame Grid(ffi::Array extents) { ObjectPtr n = ffi::make_object(); n->vars.reserve(extents.size()); n->doms.reserve(extents.size()); + n->steps.resize(extents.size()); for (const auto& extent : extents) { DataType dtype = extent.dtype(); n->vars.push_back(Var("v", extent.dtype())); n->doms.push_back(Range(make_const(dtype, 0), extent)); } - n->f_make_for_loop = [](ffi::Array vars, ffi::Array doms, Stmt body) -> Stmt { + n->f_make_for_loop = [](ffi::Array vars, ffi::Array doms, + ffi::Array> steps, Stmt body) -> Stmt { ICHECK_EQ(vars.size(), doms.size()); + ICHECK_EQ(vars.size(), steps.size()); int n = vars.size(); for (int i = n - 1; i >= 0; --i) { Range dom = doms[i]; Var var = vars[i]; body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), - /*thread_binding=*/std::nullopt, /*annotations=*/{}); + /*thread_binding=*/std::nullopt, /*annotations=*/{}, /*step=*/steps[i]); } return body; }; diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 742d23f69cdd..b2e091f38019 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -39,7 +39,7 @@ 
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     if (l->kind != tir::ForKind::kSerial ||  //
         !tir::is_zero(l->min) ||             //
         !l->annotations.empty() ||           //
-        f_var_dep(l->extent)) {
+        !l->HasTrivialStep() || f_var_dep(l->extent)) {
       break;
     }
     grid.push_back(l);
@@ -69,7 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     ffi::Optional<ExprDoc> max = std::nullopt;
     ffi::Optional<ExprDoc> annotations = std::nullopt;
     ffi::Optional<ExprDoc> thread = std::nullopt;
-    if (tir::is_zero(loop->min)) {
+    if (tir::is_zero(loop->min) && loop->HasTrivialStep()) {
       max = d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent"));
     } else {
       min = d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min"));
@@ -78,10 +78,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     if (!loop->annotations.empty()) {
       annotations = d->AsDoc<ExprDoc>(loop->annotations, loop_p->Attr("annotations"));
     }
+    bool use_range_sugar = false;
     ExprDoc prefix{ffi::UnsafeInit<ExprDoc>()};
     if (loop->kind == tir::ForKind::kSerial) {
       if (loop->annotations.empty()) {
         prefix = IdDoc("range");
+        use_range_sugar = true;
       } else {
         prefix = TIR(d, "serial");
       }
@@ -115,6 +117,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       kwargs_keys.push_back("annotations");
       kwargs_values.push_back(annotations.value());
     }
+    if (!loop->HasTrivialStep()) {
+      ExprDoc step = d->AsDoc<ExprDoc>(*loop->step, loop_p->Attr("step"));
+      if (use_range_sugar) {
+        args.push_back(step);
+      } else {
+        kwargs_keys.push_back("step");
+        kwargs_values.push_back(step);
+      }
+    }
     ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
     AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
     return ForDoc(lhs, rhs, (*f)->stmts);
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index d9ee9723216c..bc67cdad2fd3 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1152,14 +1152,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {

 void CodeGenCPU::VisitStmt_(const ForNode* op) {
   EmitDebugLocation(op);
-  ICHECK(is_zero(op->min));
   if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
     CodeGenLLVM::VisitStmt_(op);
   } else if (op->kind == ForKind::kParallel) {
+    ICHECK(is_zero(op->min)) << "Parallel launch requires a canonical loop with a zero start index";
+    ICHECK(op->HasTrivialStep()) << "Parallel launch requires a canonical loop with a trivial step";
     if (parallel_env_.penv == nullptr) {
-      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
-                               op->thread_binding, op->annotations),
-                           0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
+      auto copy_node = For(ffi::make_object<ForNode>(*op));
+      CreateParallelLaunch(copy_node, 0,
+                           std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
     } else {
       // already in parallel env.
       ICHECK(parallel_env_.task_id.defined());
@@ -1171,13 +1172,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
       ICHECK(!parallel_env_.in_parallel_loop)
           << "Nested parallel loop is not supported by threadpool, try fuse them instead";
       parallel_env_.in_parallel_loop = true;
+      PrimExpr end = is_zero(op->min) ?
op->extent : analyzer_->Simplify(op->min + op->extent); if (parallel_env_.stride_pattern) { - CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), - op->loop_var, op->body); + CreateSerialFor(MakeValue(task_id), MakeValue(end), MakeValue(num_task), op->loop_var, + op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = min(task_id * step, op->extent); - PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); + end = min((task_id + make_const(t, 1)) * step, end); CreateSerialFor(MakeValue(begin), MakeValue(end), llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bdb0c6b7389f..999e3a61eee8 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2020,7 +2020,6 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { void CodeGenLLVM::VisitStmt_(const ForNode* op) { EmitDebugLocation(op); - ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); if (op->kind == ForKind::kUnrolled) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " @@ -2028,8 +2027,11 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } else { ICHECK(op->kind == ForKind::kSerial); } - CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); + PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1)); + PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); + llvm::Value* begin_value = MakeValue(op->min); + llvm::Value* end_value = MakeValue(end); + CreateSerialFor(begin_value, end_value, MakeValue(step), op->loop_var, op->body); } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 8ebd41645aa2..52ad78166981 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -1120,13 +1120,21 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) { } void CodeGenC::VisitStmt_(const ForNode* op) { - std::string extent = PrintExpr(op->extent); + std::string begin_str = PrintExpr(op->min); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + std::string end_str = PrintExpr(end); + std::string step_str = op->step.has_value() ? 
PrintExpr(*op->step) : ""; PrintIndent(); std::string vid = AllocVarID(op->loop_var.get()); - ICHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = " << begin_str << "; " << vid << " < " << end_str << "; "; + if (step_str.empty()) { + stream << "++" << vid; + } else { + stream << vid << " += " << step_str; + } + stream << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 9565eba5d4aa..a9cfad9ab6f5 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -319,7 +319,6 @@ std::string CodeGenCUDA::Finish() { } void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { - ICHECK(is_const_int(op->min, 0)); if (op->kind == tir::ForKind::kUnrolled) { PrintIndent(); stream << "#pragma unroll\n"; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 330a54563fce..cf8176001a8a 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -667,13 +667,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { } void CodeGenWebGPU::VisitStmt_(const ForNode* op) { - std::string extent = PrintExpr(op->extent); + std::string begin_str = PrintExpr(op->min); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + std::string end_str = PrintExpr(end); + std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : ""; std::string vid = AllocVarID(op->loop_var.get()); - ICHECK(is_zero(op->min)); PrintIndent(); stream << "for (var " << vid << " : "; PrintType(op->loop_var.dtype(), stream); - stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; + stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << vid; + if (step_str.empty()) { + stream << "++"; + } else { + stream << " += " << step_str; + } + stream << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index ddbc22d88a04..4500f4219417 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -672,10 +672,21 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { } void CodeGenSPIRV::VisitStmt_(const ForNode* op) { - ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); spirv::Value init_value = MakeValue(op->min); - spirv::Value extent_value = MakeValue(op->extent); + PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); + spirv::Value end_value = MakeValue(end); + spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); + + // loop step + spirv::Value step; + if (op->HasTrivialStep()) { + step = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); + } else { + step = MakeValue(tvm::cast(end->dtype, *op->step)); + } + // Must get init label after making value(to make sure they are correct) spirv::Label init_label = builder_->CurrentLabel(); spirv::Label head_label = builder_->NewLabel(); @@ -690,9 +701,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Loop head builder_->StartLabel(head_label); - spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); - spirv::Value loop_cond = builder_->LT(loop_var, extent_value); + spirv::Value loop_cond = builder_->LT(loop_var, end_value); uint32_t control = (op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); @@ -707,9 +717,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop continue builder_->StartLabel(continue_label); - spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); - spirv::Value next_value = builder_->Add(loop_var, one); + + spirv::Value next_value = builder_->Add(loop_var, step); loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); builder_->MakeInst(spv::OpBranch, head_label); // loop merge diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index d6dcae6540ba..393ac7ee57d0 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -41,8 +41,13 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, - op->thread_binding, op->annotations); + auto n = CopyOnWrite(op); + n->min = cast(var.dtype(), op->min); + n->extent = cast(var.dtype(), op->extent); + if (op->step.has_value()) { + n->step = cast(var.dtype(), *op->step); + } + return For(n); } Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d33a01340b96..b6bca98d9179 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -132,7 +132,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - ffi::Optional thread_binding, ffi::Map annotations, Span span) { + ffi::Optional thread_binding, ffi::Map annotations, + ffi::Optional step, Span span) { ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); @@ -148,8 +149,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, require_scalar_int_dtype(min, "min"); require_scalar_int_dtype(extent, "extent"); - // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them - // without raising errors. + // When extent, min or step is an IntImm but has narrower dtype than loop_var + // we directly promote them without raising errors. 
auto try_promote_imm_dtype = [&](const PrimExpr& e) { ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) << " Loop variable's dtype (" << loop_var.dtype() @@ -168,6 +169,12 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); + if (step.has_value()) { + require_scalar_int_dtype(*step, "step"); + step = try_promote_imm_dtype(*step); + ICHECK(loop_var.dtype() == (*step).dtype()) << loop_var.dtype() << " vs " << (*step).dtype(); + } + ObjectPtr node = ffi::make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); @@ -176,19 +183,22 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, node->body = std::move(body); node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); + node->step = std::move(step); node->span = std::move(span); data_ = std::move(node); } +bool ForNode::HasTrivialStep() const { return !step.has_value() || is_one(*step); } + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, - ffi::Optional thread_binding, - ffi::Optional> annotations, Span span) { - return For(loop_var, min, extent, static_cast(kind), body, thread_binding, - annotations.value_or(ffi::Map()), span); - }); + refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, + Stmt body, ffi::Optional thread_binding, + ffi::Optional> annotations, + ffi::Optional step, Span span) { + return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + annotations.value_or(ffi::Map()), step, span); + }); } std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 80c787b11400..e6666cc63816 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -46,6 +46,9 @@ void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); + if (op->step.has_value()) { + this->VisitExpr(*op->step); + } this->VisitStmt(op->body); } @@ -260,13 +263,19 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); + ffi::Optional step{std::nullopt}; + if (op->step.has_value()) { + step = this->VisitExpr(*op->step); + } Stmt body = this->VisitStmt(op->body); - if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body) && + step.same_as(op->step)) { return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); n->extent = std::move(extent); + n->step = std::move(step); n->body = std::move(body); return Stmt(n); } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index fbc569ece689..2ae32ea66a6a 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -703,7 +703,7 @@ class BlockizeRewriter : public StmtMutator { Stmt VisitStmt_(const ForNode* loop) final { if (loop == lca_->stmt) { return For(loop->loop_var, 
loop->min, loop->extent, loop->kind, RewriteSeq(loop->body), - loop->thread_binding, loop->annotations, loop->span); + loop->thread_binding, loop->annotations, loop->step, loop->span); } return StmtMutator::VisitStmt_(loop); } diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 5499ab9c58d0..7e61fd4eb20a 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -343,7 +343,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min; PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent; nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root, - loop->thread_binding, loop->annotations, loop->span); + loop->thread_binding, loop->annotations, loop->step, loop->span); if (loop.same_as(highest_pos_inclusive)) { break; } diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index b2c64e65e568..3cd364b0fd2b 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -1137,8 +1137,8 @@ void Reorder(ScheduleState self, const ffi::Array& ordered_loop_srefs) StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { if (sref->stmt->IsInstance()) { - For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, - ffi::GetRef(sref->stmt)); + For new_loop = + For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, ffi::GetRef(sref->stmt)); self->Replace(sref, new_loop, {}); return self->stmt2ref.at(new_loop.get()); } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 49dc31e6f6e5..0629757a13d8 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -268,7 +268,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, std::unordered_map loop_var_map; Stmt body = BlockRealize(init_realize); for (int i : chosen_loops) { - const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); + For old_loop = ffi::GetRef(TVM_SREF_TO_FOR(loops[i])); // Create a new equivalent to the chosen loop Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); @@ -280,12 +280,11 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, thread_binding.CopyOnWrite()->var = new_var; opt_thread_binding = thread_binding; } - body = For(/*loop_var=*/new_loop_var, - /*min=*/old_loop->min, - /*extent=*/old_loop->extent, - /*kind=*/old_loop->kind, - /*body=*/body, - /*thread_binding=*/opt_thread_binding); + auto new_loop = old_loop.CopyOnWrite(); + new_loop->loop_var = new_loop_var; + new_loop->thread_binding = opt_thread_binding; + new_loop->body = body; + body = ffi::GetRef(new_loop); } body = Substitute(body, loop_var_map); // Step 6. Mutate IR diff --git a/src/tir/transforms/canonicalize_loop.cc b/src/tir/transforms/canonicalize_loop.cc new file mode 100644 index 000000000000..93511bf84bb2 --- /dev/null +++ b/src/tir/transforms/canonicalize_loop.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/transforms/canonicalize_loop.cc
+ * \brief Canonicalize all loops to start from zero and step one.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/transform.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+class LoopCanonicalizer : public StmtExprMutator {
+ public:
+  LoopCanonicalizer() = default;
+
+ private:
+  Stmt VisitStmt_(const ForNode* op) final {
+    if (is_zero(op->min) && op->HasTrivialStep()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    arith::Analyzer analyzer;
+    const auto* loop_var = op->loop_var.get();
+    PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1));
+
+    // Report an error when the step cannot be proven positive, since a
+    // non-positive step would produce a loop that never terminates.
+    if (!analyzer.CanProveGreaterEqual(step, 1)) {
+      // TODO(tvm): support dynamically shaped steps that are provably positive
+      LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step;
+    }
+
+    new_iter_info_[loop_var] = std::make_pair(step, op->min);
+    auto n = CopyOnWrite(op);
+    n->body = VisitStmt(op->body);
+    n->min = make_zero(loop_var->dtype);
+    n->extent = analyzer.Simplify(ceildiv(op->extent, step));
+    n->step = std::nullopt;
+    new_iter_info_.erase(loop_var);
+    return For(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = new_iter_info_.find(op);
+    if (it != new_iter_info_.end()) {
+      const auto& [stride, offset] = it->second;
+      return ffi::GetRef<PrimExpr>(op) * stride + offset;
+    }
+    return ffi::GetRef<PrimExpr>(op);
+  }
+
+  /*! \brief Map iter variable `x` to `x * stride + offset`. */
+  std::unordered_map<const VarNode*, std::pair<PrimExpr, PrimExpr>> new_iter_info_;
+};
+
+PrimFunc CanonicalizeLoop(PrimFunc func) {
+  PrimFuncNode* fptr = func.CopyOnWrite();
+  fptr->body = LoopCanonicalizer()(func->body);
+  return func;
+}
+
+namespace transform {
+
+Pass CanonicalizeLoop() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    return CanonicalizeLoop(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.CanonicalizeLoop", {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("tir.transform.CanonicalizeLoop", CanonicalizeLoop);
+}
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc
index dfeb7fe2e219..9b9619fae937 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -602,7 +602,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) {
     // Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new`
     // that have just been obtained
     return For(op->loop_var, min_new, extent_new, op->kind, body_new, op->thread_binding,
-               op->annotations, op->span);
+               op->annotations, op->step, op->span);
   }
 }

diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc
index a8b30ebf9101..691d8b885c59 100644
--- a/src/tir/transforms/convert_for_loops_serial.cc
+++ b/src/tir/transforms/convert_for_loops_serial.cc
@@ -43,7 +43,7 @@ class ForLoopSerialConverter : public StmtExprMutator {
 Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) {
   if (op->kind == ForKind::kParallel) {
     return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding,
-               op->annotations, op->span);
+               op->annotations, op->step, op->span);
   }
   return StmtExprMutator::VisitStmt_(op);
 }
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index af1b7c8bdfa5..f4258fc479d6 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -943,7 +943,7 @@ class PipelineRewriter : public StmtExprMutator {
     if (!is_unit_loop) {
       new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                      unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
-                     std::nullopt, preserved_annotations_);
+                     std::nullopt, preserved_annotations_, std::nullopt);
     }

     // Update producer heads in the global async states.
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index dba13cfbbcf1..8bcb2077c677 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -362,9 +362,9 @@ class IRConvertSSA final : public StmtExprMutator { if (defined_.count(v.get())) { ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, - op->annotations); + auto n = ffi::make_object(*stmt.as()); + n->loop_var = redefine.new_var; + return For(n); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 2dffc11b7257..45bbf4af52de 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -133,7 +133,7 @@ class ThreadBindingLifter : public StmtExprMutator { ForKind::kThreadBinding, std::move(body), IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype), kThreadIndex, iter_var->thread_tag), - annotation); + annotation, std::nullopt); } } if (is_kernel_root) { diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index e644c387cf5a..fd9bd2d6531c 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -760,14 +760,18 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { const ForNode* for_node = static_cast(node); ICHECK(for_node); + if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { ICHECK(for_node->kind != ForKind::kThreadBinding); - return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body, - for_node->thread_binding, for_node->annotations); + auto new_loop = ffi::make_object(*for_node); + new_loop->min = IntImm(for_node->min.dtype(), 0); + new_loop->extent = extent; + new_loop->body = body; + return For(new_loop); } } diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 25e8734ff1c6..2f7ac3ddb1c0 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -878,7 +878,9 @@ class CrossThreadReductionTransformer : public StmtMutator { /*body=*/body, // /*thread_binding=*/ IterVar(NullValue(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, - "threadIdx." + dim_index)); + "threadIdx." + dim_index), + /*annotations=*/{}, + /*step=*/std::nullopt); } return body; } diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 2e53e89667cc..c0363dd8982f 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -111,7 +111,7 @@ class OpaqueBlockLower : public StmtExprMutator { } else { // Case 3. An ordinary loop body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body), - std::nullopt, new_annotations); + std::nullopt, new_annotations, op->step); } // Step 5. 
Insert nested attrs for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 094f48e321f6..0d5b27044232 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -128,7 +128,8 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); for (int i = n - 2; i >= 1; i--) { body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), - IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]), + {}, std::nullopt); } return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); } diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index e16c51877188..e69ac30366b1 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -70,8 +70,9 @@ std::pair> TileWmmaBlock(Stmt stmt) { } For compute_location = Downcast(body); for (int i = n - 3; i >= 0; i--) { - body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), - loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(body); + body = new_loop; } return {body, compute_location}; } @@ -187,8 +188,9 @@ Stmt RewriteWmmaLoad(Stmt stmt) { }, /*annotations=*/{})); for (int i = n - 3; i >= 0; i--) { - wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(wmma_body); + wmma_body = new_loop; } return wmma_body; } @@ -290,8 +292,9 @@ Stmt RewriteWmmaStore(Stmt stmt) { }, /*annotations=*/{})); for (int i = n - 3; i >= 0; i--) { - wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(wmma_body); + wmma_body = new_loop; } return wmma_body; } @@ -395,8 +398,9 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { } For compute_location = Downcast(body); for (int i = n - 3; i >= 0; i--) { - body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), - loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(body); + body = new_loop; } return {body, compute_location}; } @@ -484,21 +488,21 @@ Stmt RewriteMmaStore(Stmt stmt) { /*reads=*/{BufferRegion(src_buffer, read_region)}, /*writes=*/{BufferRegion(tgt_buffer, write_region)}, /*name_hint=*/"mma_store", - AttrStmt(/*node=*/IterVar( - /*dom=*/Range::FromMinExtent(0, 32), - /*var=*/tx, - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/"threadIdx.x"), - /*attr_key=*/"thread_extent", - /*value=*/Integer(32), - /*body=*/ - For(vec, 0, 2, ForKind::kVectorized, - /*body=*/ - BufferStore(new_tgt_buffer, - BufferLoad(new_src_buffer, - {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), - {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), - /*annotations=*/{})), + AttrStmt( + 
/*node=*/IterVar( + /*dom=*/Range::FromMinExtent(0, 32), + /*var=*/tx, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/"threadIdx.x"), + /*attr_key=*/"thread_extent", + /*value=*/Integer(32), + /*body=*/ + For(vec, 0, 2, ForKind::kVectorized, + /*body=*/ + BufferStore( + new_tgt_buffer, + BufferLoad(new_src_buffer, {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), + {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))), /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/ @@ -510,8 +514,9 @@ Stmt RewriteMmaStore(Stmt stmt) { // Step 3.4. wrap outer loops for (int i = n - 3; i >= 0; i--) { - mma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(mma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(mma_body); + mma_body = new_loop; } return mma_body; } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 4af12c69a3b8..830364788c5e 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -510,7 +510,7 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), - op->thread_binding, op->annotations); + op->thread_binding, op->annotations, op->step); } else { return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index fa1e221459c0..502acd5a467e 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -79,7 +79,8 @@ class ThreadBindingUnifier : public StmtExprMutator { /*extent=*/IntImm(dtype, 1), // /*kind=*/ForKind::kSerial, stmt, // /*thread_binding=*/std::nullopt, // - /*annotation=*/std::move(annotations)); + /*annotation=*/std::move(annotations), + /*step=*/std::nullopt); } } @@ -155,7 +156,8 @@ class ThreadBindingUnifier : public StmtExprMutator { result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, ForKind::kThreadBinding, result, IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, - thread_binding->thread_tag)); + thread_binding->thread_tag), + {}, std::nullopt); launch_threads_.pop_back(); } return result; diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index d1269634ab4b..74abea57ba97 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -156,8 +156,9 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->kind != ForKind::kUnrolled) { - return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, - op->thread_binding, op->annotations); + auto n = CopyOnWrite(op); + n->kind = ForKind::kUnrolled; + return For(n); } } return stmt; diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 857f0b4cea99..068903baa814 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -752,8 +752,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent) && body.same_as(op->body)) { return ffi::GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, - op->annotations); + auto n = CopyOnWrite(op); + n->extent = extent; + n->body = body; + return For(n); } } // IfThenElse diff --git 
a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py index 3332d015a818..7530786a38d7 100644 --- a/tests/python/codegen/test_target_codegen.py +++ b/tests/python/codegen/test_target_codegen.py @@ -16,7 +16,7 @@ # under the License. import pytest - +import numpy as np import tvm from tvm.script import tir as T @@ -88,5 +88,47 @@ def func(a: T.handle, b: T.handle): tvm.compile(func) +@tvm.testing.parametrize_targets("c", "llvm") +def test_codegen_loop_step(target): + @T.prim_func + def test_loop_step( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + for i in T.serial(3, 1024, step=96): + C[i] = A[i] + B[i] + + with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + lib = tvm.compile(test_loop_step, target=target) + + src = lib.mod.inspect_source() + if target == "c": + assert src.find("for (int32_t i = 3; i < 1024; i += 96)") >= 0 + + dev = tvm.device(target, 0) + a_np = np.random.rand(1024).astype("float32") + b_np = np.random.rand(1024).astype("float32") + c_np = np.zeros(1024, dtype="float32") + a_tvm = tvm.runtime.tensor(a_np, dev) + b_tvm = tvm.runtime.tensor(b_np, dev) + c_tvm = tvm.runtime.tensor(c_np, dev) + + lib(a_tvm, b_tvm, c_tvm) + + c_result = c_tvm.numpy() + + # Check that the loop executes at positions 3, 99, 195, 291, 387, 483, 579, 675, 771, 867, 963 + for i in range(3, 1024, 96): + np.testing.assert_allclose(c_result[i], a_np[i] + b_np[i], rtol=1e-5) + + # Assert non-touched positions remain zero + for i in range(0, 3): + assert c_result[i] == 0.0 + for i in range(4, 1024): + if (i - 3) % 96 != 0: + assert c_result[i] == 0.0 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 0841d0f54562..1b31e64414b1 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -877,5 +877,37 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): assert "return;" in cuda_code +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_cuda_loop_step(): + @T.prim_func + def cuda_loop_step( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + # Each thread computes a strided subset of the i loop: start = tx*3, step = 96 (3 * 32 threads) + for bx in T.thread_binding(1, "blockIdx.x"): + for tx in T.thread_binding(96, "threadIdx.x"): + for i in T.serial(tx, 1024, step=96): + C[i] = A[i] + B[i] + + target = tvm.target.Target({"kind": "cuda"}) + with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + lib = tvm.compile(cuda_loop_step, target=target) + + cuda_src = lib.mod.imports[0].inspect_source() + assert "i += 96" in cuda_src + dev = tvm.cuda(0) + a_np = np.random.uniform(1, 100, (1024,)).astype("float32") + b_np = np.random.uniform(1, 100, (1024,)).astype("float32") + c_np = np.zeros((1024,), dtype="float32") + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = tvm.runtime.tensor(b_np, dev) + c_nd = tvm.runtime.tensor(c_np, dev) + lib["main"](a_nd, b_nd, c_nd) + tvm.testing.assert_allclose(c_nd.numpy(), a_np + b_np) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 5e1d25e48b0d..d546d2e13c7b 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -134,6 
+134,7 @@ def test_basic(): def test_stmt(): x = tvm.tir.Evaluate(0) tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x) + tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2) def test_dir(): diff --git a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py new file mode 100644 index 000000000000..6f6d88137c20 --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir +from tvm.script import tir as T + + +def test_canonicalize_loop(): + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in range(1, 128, 5): + B[i] = A[i] + 1.0 + + @T.prim_func + def expected(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 26): + B[i * 5 + 1] = A[i * 5 + 1] + 1.0 + + mod = tvm.IRModule.from_expr(before) + mod = tir.transform.CanonicalizeLoop()(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_canonicalize_nested_loop(): + @T.prim_func + def before(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in range(1, 128, 5): + for j in range(2, 128, 3): + B[i, j] = A[i, j] + 1.0 + + @T.prim_func + def expected(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 26): + for j in T.serial(0, 42): + B[i * 5 + 1, j * 3 + 2] = A[i * 5 + 1, j * 3 + 2] + 1.0 + + mod = tvm.IRModule.from_expr(before) + mod = tir.transform.CanonicalizeLoop()(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_canonicalize_negative_step(): + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 127, step=-3): + B[i] = A[i] + 1.0 + + mod = tvm.IRModule.from_expr(before) + with pytest.raises(tvm.error.InternalError): + mod = tir.transform.CanonicalizeLoop()(mod) + + +def test_canonicalize_dynamic_step(): + """Currently we report error for dynamic step since we could not prove it is positive""" + + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"], step: T.int32): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 128, step=step): + B[i] = A[i] + 1.0 + + mod = tvm.IRModule.from_expr(before) + with pytest.raises(tvm.error.InternalError): + mod = tir.transform.CanonicalizeLoop()(mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git 
a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index f1569be5b1f4..3b84e919c8bd 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -327,6 +327,32 @@ def non_starred(a: T.handle, b: T.handle): tvm.ir.assert_structural_equal(starred, non_starred) +def test_tir_loop_steps(): + N = T.Var("N", "int32") + + @T.prim_func(private=True) + def loop_with_steps( + A: T.Buffer((N,)), B: T.Buffer((N,)), C: T.Buffer((N,)), tid: T.int32, v: T.int32 + ): + for i in T.serial(tid, N, step=2): + C[i] = A[i] + B[i] + for i in T.unroll(tid, N, step=3): + C[i] = A[i] + B[i] + for i in T.vectorized(tid, N, step=4): + C[i] = A[i] + B[i] + for i in T.parallel(tid, N, step=5): + C[i] = A[i] + B[i] + for i in T.serial(tid, N, step=v): + C[i] = A[i] + B[i] + + stmts = loop_with_steps.body.seq + assert stmts[0].step == 2 + assert stmts[1].step == 3 + assert stmts[2].step == 4 + assert stmts[3].step == 5 + assert stmts[4].step.name == "v" + + def test_tir_empty_tuple_index(): @T.macro def bar(val): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 1954ca773f14..b3d459b2e67f 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4018,6 +4018,25 @@ def func(In: T.Buffer((1,), "int32"), Out: T.Buffer((2,), "int32")): return func +def func_with_loop_steps(): + @T.prim_func + def func( + A: T.Buffer((1024,)), B: T.Buffer((1024,)), C: T.Buffer((1024,)), tid: T.int32, v: T.int32 + ): + for i in T.serial(tid, 1024, step=2): + C[i] = A[i] + B[i] + for i in T.unroll(tid, 1024, step=3): + C[i] = A[i] + B[i] + for i in T.vectorized(tid, 1024, step=4): + C[i] = A[i] + B[i] + for i in T.parallel(tid, 1024, step=5): + C[i] = A[i] + B[i] + for i in range(tid, 1024, 6): + C[i] = A[i] + B[i] + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4237,6 +4256,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero_private_with_attr, func_attr_with_list, func_with_loop_jumps, + func_with_loop_steps, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var,
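
---

For orientation, a minimal usage sketch distilled from the tests above (assumes this patch is applied; the function name, buffer shapes, and values are illustrative, not part of the patch):

# The TVMScript `range` sugar accepts a positive literal step, which lowers
# to a tir.For with step=5; CanonicalizeLoop then rewrites the loop to start
# from zero with a unit step, folding the stride into the indexing.
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def strided(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i in range(1, 128, 5):
        B[i] = A[i] + 1.0


mod = tir.transform.CanonicalizeLoop()(tvm.IRModule.from_expr(strided))
# After canonicalization the loop runs ceildiv(127, 5) == 26 iterations over
# `i * 5 + 1`, matching test_canonicalize_loop above.
print(mod["main"].script())

Backends that cannot consume non-canonical loops directly (e.g. the parallel CPU launch path) now assert on a zero start and trivial step, which is why the default pipeline schedules CanonicalizeLoop first; the codegen tests disable the pass explicitly to exercise the native stepped-loop emission.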