Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
* \param loop_body The loop body
* \return A stmt, the loop nest
*/
using FMakeForLoop =
ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
ffi::Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt loop_body)>;
/*! \brief The loop variable. */
ffi::Array<tvm::tir::Var> vars;
/*! \brief The domains of iteration. */
ffi::Array<Range> doms;
/*! \brief The optional steps of iteration. */
ffi::Array<ffi::Optional<PrimExpr>> steps;
/*! \brief The for loop generating function. */
FMakeForLoop f_make_for_loop;

Expand Down
16 changes: 12 additions & 4 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,37 +228,45 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Serial(PrimExpr start, PrimExpr stop,
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The parallel For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Parallel(PrimExpr start, PrimExpr stop,
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The vectorized For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Vectorized(PrimExpr start, PrimExpr stop,
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The unrolled For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Unroll(PrimExpr start, PrimExpr stop,
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The thread-binding For statement.
* \param start The minimum value of iteration.
Expand Down
17 changes: 13 additions & 4 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ enum class ForKind : int {
*
* \code
*
* for (loop_var = min; loop_var < min + extent; ++loop_var) {
* for (loop_var = min; loop_var < min + extent; loop_var += step) {
* // body
* }
* \endcode
Expand Down Expand Up @@ -748,6 +748,10 @@ class ForNode : public StmtNode {
* and can be ignored in most passes.
*/
ffi::Map<ffi::String, ffi::Any> annotations;
/*!
* \brief The loop step. It is one if not specified.
*/
ffi::Optional<PrimExpr> step;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
Expand All @@ -758,8 +762,13 @@ class ForNode : public StmtNode {
.def_ro("kind", &ForNode::kind)
.def_ro("body", &ForNode::body)
.def_ro("thread_binding", &ForNode::thread_binding)
.def_ro("annotations", &ForNode::annotations);
.def_ro("annotations", &ForNode::annotations)
.def_ro("step", &ForNode::step);
}

/*! \brief Check it is a loop without nontrivial loop step. */
bool HasTrivialStep() const;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
};

Expand All @@ -771,8 +780,8 @@ class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ffi::Optional<IterVar> thread_binding = std::nullopt,
ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
Span span = Span());
ffi::Map<ffi::String, ffi::Any> annotations = {},
ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
Expand Down
44 changes: 36 additions & 8 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,11 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L


def serial(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None,
step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The serial For statement.

Expand All @@ -692,6 +696,9 @@ def serial(
annotations : Dict[str, Any]
The optional annotations of the For statement.

step : PrimExpr
The optional step value of iteration.

Returns
-------
res : frame.ForFrame
Expand All @@ -703,11 +710,15 @@ def serial(
start = IntImm(start.dtype, 0)
else:
start = 0
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member


def parallel(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None,
step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The parallel For statement.

Expand All @@ -722,6 +733,9 @@ def parallel(
annotations : Dict[str, Any]
The optional annotations of the For statement.

step : PrimExpr
The optional step value of iteration.

Returns
-------
res : frame.ForFrame
Expand All @@ -733,11 +747,15 @@ def parallel(
start = IntImm(start.dtype, 0)
else:
start = 0
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member


def vectorized(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None,
step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The vectorized For statement.

Expand All @@ -752,6 +770,9 @@ def vectorized(
annotations : Dict[str, Any]
The optional annotations of the For statement.

step : PrimExpr
The optional step value of iteration.

Returns
-------
res : frame.ForFrame
Expand All @@ -763,11 +784,15 @@ def vectorized(
start = IntImm(start.dtype, 0)
else:
start = 0
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member


def unroll(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None,
step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The unrolled For statement.

Expand All @@ -782,6 +807,9 @@ def unroll(
annotations : Dict[str, Any]
The optional annotations of the For statement.

step : PrimExpr
The optional step value of iteration.

Returns
-------
res : frame.ForFrame
Expand All @@ -793,7 +821,7 @@ def unroll(
start = IntImm(start.dtype, 0)
else:
start = 0
return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member


def thread_binding(
Expand Down
27 changes: 25 additions & 2 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import contextlib
from functools import partial
from typing import Any
from typing import Any, Dict, Optional

import tvm
from tvm.ir import GlobalVar, PrimType
Expand Down Expand Up @@ -168,6 +168,28 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
return default


def range_sugar(
start: PrimExpr,
stop: PrimExpr = None,
step: Optional[PrimExpr] = None,
*,
annotations: Dict[str, Any] = None,
) -> T.frame.ForFrame:
"""The sugar for python range builtin."""

# Since `tir.For` do not support reversed iteration semantic,
# the step must be checked to be positive integer when use range sugar
if step is not None:
try:
step = int(step)
if step <= 0:
raise ValueError(f"Only support positive step in range(), get {step}")
except TypeError: # pylint: disable=broad-except
raise ValueError(f"Only support literal step in range(), get {step}")

return T.serial(start, stop, annotations=annotations, step=step)


@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
"""The for visiting method for tir.
Expand Down Expand Up @@ -379,7 +401,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
privacy = find_decorator_annotation(node, "private", default=False)
self.function_annotations = None
with self.var_table.with_frame():
self.var_table.add("range", T.serial)

self.var_table.add("range", range_sugar)
with T.prim_func(is_private=privacy):
T.func_name(node.name)
if node.returns is not None:
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
def for_range(self, begin, end, name="i", dtype=None, kind="serial", step=None):
"""Create a for iteration scope.

Parameters
Expand All @@ -223,6 +223,10 @@ def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
kind : str, optional
The special tag on the for loop.

step : PrimExpr
The loop step. Default to none which
represent one.

Returns
-------
loop_scope : With.Scope of Var
Expand Down Expand Up @@ -275,7 +279,7 @@ def _exit_cb():
kind_id = _stmt.ForKind.UNROLLED
else:
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), step=step))

return WithScope(loop_var, _exit_cb)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
pass_ctx = tvm.transform.PassContext.current()
config = pass_ctx.config
passes = [
tir.transform.CanonicalizeLoop(),
tir.transform.LowerCrossThreadReduction(),
tir.transform.LowerInitBlock(),
tir.transform.PlanAndUpdateBufferAllocationLocation(),
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ class For(Stmt):
The thread this loop binds to. Only valid
if kind is ThreadBinding

step : PrimExpr
The loop step. Default to none which
represent one.

annotations: Optional[Mapping[str, Object]]
Additional annotation hints.

Expand All @@ -159,6 +163,7 @@ class For(Stmt):
body: Stmt
thread_binding: Optional[IterVar]
annotations: Mapping[str, Object]
step: Optional[PrimExpr]
span: Optional[Span]

def __init__(
Expand All @@ -170,6 +175,7 @@ def __init__(
body: Stmt,
thread_binding: Optional[IterVar] = None,
annotations: Optional[Mapping[str, Object]] = None,
step: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
Expand All @@ -181,6 +187,7 @@ def __init__(
body,
thread_binding,
annotations,
step,
span,
)

Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,14 @@ def LowerVtcmAlloc():
The result pass
"""
return _ffi_api.LowerVtcmAlloc() # type: ignore


def CanonicalizeLoop():
"""Canonicalize the loop to start from zero and use trivial step

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CanonicalizeLoop() # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
if (shard > 1) {
arith::Analyzer analyzer;
ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
return new_loop;
}
}
return new_loop;
Expand Down
2 changes: 1 addition & 1 deletion src/script/ir_builder/tir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() {

void ForFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts)));
}

void AssertFrameNode::ExitWithScope() {
Expand Down
Loading
Loading