Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Schedule Primitive: Add-Unit-Loop #11575

Merged
merged 2 commits into from
Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,18 @@ class ScheduleNode : public runtime::Object {
* \param ordered_loop_rvs The loops in the new order
*/
virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
/*!
 * \brief Create a new unit loop (a loop with extent 1) on top of the specific block.
 * \param block_rv The block above which the new loop is created
 * \return The new loop created
 */
virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
/*!
 * \brief Create a new unit loop (a loop with extent 1) on top of the specific loop.
 * \param loop_rv The loop above which the new loop is created
 * \return The new loop created
 */
virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
Expand Down
64 changes: 60 additions & 4 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
from typing import Callable, Dict, List, Optional, Union, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
from ..function import IndexMap
from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc

from ..function import IndexMap
from . import _ffi_api
from ._type_checker import type_checked
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
from .trace import Trace
from ._type_checker import type_checked


@register_error
Expand Down Expand Up @@ -685,6 +685,62 @@ def after_reorder(a: T.handle, b: T.handle) -> None:
"""
_ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member

@type_checked
def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV:
"""Create a new unit loop on top of the specific block or loop.

Parameters
----------
block_or_loop : Union[LoopRV, BlockRV]
The block or loop above which the new loop is created

Returns
-------
new_loop : LoopRV
The new unit loop

Examples
--------

Before add_unit_loop, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_add_unit_loop(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]

Create the schedule and do add-unit-loop:

.. code-block:: python

sch = tir.Schedule(before_add_unit_loop)
sch.add_unit_loop(sch.get_block("C"))
print(sch.mod["main"].script())

After applying add-unit-loop, the IR becomes:

.. code-block:: python

@T.prim_func
def after_add_unit_loop(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
for u in T.serial(1):
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]
"""
return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop) # type: ignore # pylint: disable=no-member

########## Schedule: Manipulate ForKind ##########

@type_checked
Expand Down
18 changes: 18 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,24 @@ void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
this->state_->DebugVerify();
}

/*!
 * \brief Add a unit loop above the block referred to by `block_rv`.
 * \param block_rv The block above which the new loop is created
 * \return A random variable bound to the newly created loop
 */
LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
  LoopRV new_loop_rv{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
  // `tir::` is spelled out so the free primitive is called rather than this member function.
  new_loop_rv = CreateRV<LoopRV>(
      tir::AddUnitLoop(/*self=*/state_, /*sref=*/GetSRef(block_rv)));
  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
  // Runs after the primitive has been applied.
  this->state_->DebugVerify();
  return new_loop_rv;
}

/*!
 * \brief Add a unit loop above the loop referred to by `loop_rv`.
 * \param loop_rv The loop above which the new loop is created
 * \return A random variable bound to the newly created loop
 */
LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
  LoopRV new_loop_rv{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
  // `tir::` is spelled out so the free primitive is called rather than this member function.
  new_loop_rv = CreateRV<LoopRV>(
      tir::AddUnitLoop(/*self=*/state_, /*sref=*/GetSRef(loop_rv)));
  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
  // Runs after the primitive has been applied.
  this->state_->DebugVerify();
  return new_loop_rv;
}

/******** Schedule: Manipulate ForKind ********/

void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class ConcreteScheduleNode : public ScheduleNode {
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
// Overrides of the two ScheduleNode::AddUnitLoop overloads: add a unit loop above a block, or
// above a loop.
LoopRV AddUnitLoop(const BlockRV& block_rv) override;
LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
/******** Schedule: Manipulate ForKind ********/
void Parallel(const LoopRV& loop_rv) override;
void Vectorize(const LoopRV& loop_rv) override;
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
*/
TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);

/*!
 * \brief Create a new unit loop on top of the specific block or loop.
 * \param self The state of the schedule
 * \param sref The block/loop above which the new unit loop is created
 * \return The sref of the new unit loop
 */
TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref);

/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
Expand Down
69 changes: 69 additions & 0 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,43 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
self->Replace(GetRef<StmtSRef>(top), new_loop, {});
}

/*!
 * \brief Wrap the loop or block referenced by `sref` in a new serial loop named `u`
 *        with min 0 and extent 1.
 * \param self The schedule state that is updated through `Replace`
 * \param sref The sref of the loop or block above which the unit loop is inserted
 * \return The sref of the newly created unit loop, looked up in `self->stmt2ref`
 * \note A CHECK fails when `sref` is a block whose sref has no parent (the root block).
 */
StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
  // Mutator used for the block case: it locates the BlockRealize of the target block and
  // returns a unit loop around it, remembering the created loop in `new_loop_`.
  class NewLoopCreator : public StmtMutator {
   public:
    explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {}

    Stmt VisitStmt_(const BlockRealizeNode* realize) final {
      // Not the target block: fall back to the default traversal.
      if (realize->block.get() != src_block_) {
        return StmtMutator::VisitStmt_(realize);
      }
      new_loop_ =
          For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<BlockRealize>(realize));
      return new_loop_;
    }

    const StmtNode* src_block_;
    For new_loop_{nullptr};
  };

  // Loop case: the existing loop becomes the body of the unit loop, which replaces it directly.
  if (sref->stmt->IsInstance<ForNode>()) {
    For unit_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
    self->Replace(sref, unit_loop, {});
    return self->stmt2ref.at(unit_loop.get());
  }

  // Block case: rewrite the parent statement so that the block's realize is wrapped.
  CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";
  StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
  NewLoopCreator wrapper(sref->stmt);
  Stmt rewritten_parent = wrapper(GetRef<Stmt>(parent_sref->stmt));
  if (!rewritten_parent->IsInstance<ForNode>()) {
    // The parent is a block: pass the old -> new block pair to Replace as the block reuse map.
    Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
    Block new_parent_block = Downcast<Block>(rewritten_parent);
    self->Replace(parent_sref, rewritten_parent, {{old_parent_block, new_parent_block}});
  } else {
    // The parent is a loop: no block is replaced, so the reuse map is empty.
    self->Replace(parent_sref, std::move(rewritten_parent), {});
  }
  return self->stmt2ref.at(wrapper.new_loop_.get());
}

/******** InstructionKind Registration ********/

struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
Expand Down Expand Up @@ -800,9 +837,41 @@ struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Instruction traits of the "AddUnitLoop" schedule instruction.
 *
 * The instruction takes one input (a block or a loop random variable), no attributes and no
 * decisions, and is marked as not pure.
 */
struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
  static constexpr const char* kName = "AddUnitLoop";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*!
   * \brief Apply the instruction by dispatching on the runtime type of `rv`.
   * \param sch The schedule to apply the instruction to
   * \param rv Either a BlockRV or a LoopRV; any other type is a fatal TypeError
   * \return The random variable of the new unit loop
   */
  static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) {
    if (const BlockRVNode* block_node = rv.as<BlockRVNode>()) {
      return sch->AddUnitLoop(GetRef<BlockRV>(block_node));
    }
    if (const LoopRVNode* loop_node = rv.as<LoopRVNode>()) {
      return sch->AddUnitLoop(GetRef<LoopRV>(loop_node));
    }
    LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block";
    throw;
  }

  /*!
   * \brief Render the instruction as a call to the Python method `add_unit_loop`.
   * \param outputs The names of the output variables
   * \param rv The name of the input block or loop variable
   * \return The rendered Python statement
   */
  static String UnpackedAsPython(Array<String> outputs, String rv) {
    PythonAPICall call("add_unit_loop");
    call.Input("block_or_loop", rv);
    call.SingleOutput(outputs);
    return call.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

// Register the loop-transformation traits structs as instruction kinds.
// NOTE(review): the registration behavior is inferred from the macro name; the macro is defined
// elsewhere.
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);

} // namespace tir
} // namespace tvm
12 changes: 12 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&Sche
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
.set_body_method<Schedule>(&ScheduleNode::Reorder);
// FFI entry for `AddUnitLoop`: picks the overload that matches the runtime type of `rv`
// (a loop or a block random variable) and fails with a TypeError for anything else.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop")
    .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV {
      if (const LoopRVNode* loop_node = rv.as<LoopRVNode>()) {
        return self->AddUnitLoop(GetRef<LoopRV>(loop_node));
      }
      if (const BlockRVNode* block_node = rv.as<BlockRVNode>()) {
        return self->AddUnitLoop(GetRef<BlockRV>(block_node));
      }
      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
                 << ". Its value is: " << rv;
      throw;
    });
/******** (FFI) Manipulate ForKind ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
.set_body_method<Schedule>(&ScheduleNode::Parallel);
Expand Down
22 changes: 22 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,28 @@ void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
/*outputs=*/{}));
}

/*!
 * \brief Add a unit loop above a block and append the matching instruction to the trace.
 * \param block_rv The block above which the new loop is created
 * \return The random variable of the new loop
 */
LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
  // Apply the primitive through the base class first; its output is recorded below.
  LoopRV new_loop_rv = ConcreteScheduleNode::AddUnitLoop(block_rv);

  static const InstructionKind& add_unit_loop_kind = InstructionKind::Get("AddUnitLoop");
  Instruction record(/*kind=*/add_unit_loop_kind,
                     /*inputs=*/{block_rv},
                     /*attrs=*/{},
                     /*outputs=*/{new_loop_rv});
  trace_->Append(/*inst=*/record);
  return new_loop_rv;
}

/*!
 * \brief Add a unit loop above a loop and append the matching instruction to the trace.
 * \param loop_rv The loop above which the new loop is created
 * \return The random variable of the new loop
 */
LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
  // Apply the primitive through the base class first; its output is recorded below.
  LoopRV new_loop_rv = ConcreteScheduleNode::AddUnitLoop(loop_rv);

  static const InstructionKind& add_unit_loop_kind = InstructionKind::Get("AddUnitLoop");
  Instruction record(/*kind=*/add_unit_loop_kind,
                     /*inputs=*/{loop_rv},
                     /*attrs=*/{},
                     /*outputs=*/{new_loop_rv});
  trace_->Append(/*inst=*/record);
  return new_loop_rv;
}

/******** Schedule: Manipulate ForKind ********/

void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
// Traced variants of AddUnitLoop (above a block / above a loop): each calls the
// ConcreteScheduleNode implementation and appends an "AddUnitLoop" instruction to the trace.
LoopRV AddUnitLoop(const BlockRV& block_rv) final;
LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
/******** Schedule: Manipulate ForKind ********/
void Parallel(const LoopRV& loop_rv) final;
void Vectorize(const LoopRV& loop_rv) final;
Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,5 +524,63 @@ def test_fuse_not_affine():
verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine)


def test_add_unit_loop_above_block():
"""add_unit_loop on block "C" of a loop-free function wraps the block in one extent-1 loop."""
@T.prim_func
def zero_dim(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]

@T.prim_func
def zero_dim_added(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
for u in range(1):
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]

sch = tir.Schedule(zero_dim, debug_mask="all")
block = sch.get_block("C")
sch.add_unit_loop(block)
tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])


def test_add_unit_loop_above_loop():
"""add_unit_loop on the single loop around block "C" yields two nested extent-1 loops."""
@T.prim_func
def zero_dim(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
for u in range(1):
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]

@T.prim_func
def zero_dim_added(
A: T.Buffer[(), "int32"],
B: T.Buffer[(), "int32"],
C: T.Buffer[(), "int32"],
) -> None:
for u1, u2 in T.grid(1, 1):
with T.block("C"):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]

sch = tir.Schedule(zero_dim, debug_mask="all")
block = sch.get_block("C")
(loop,) = sch.get_loops(block)
sch.add_unit_loop(loop)
tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])


if __name__ == "__main__":
tvm.testing.main()