Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add schedule primitive ReIndex #11515

Merged
merged 1 commit into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
* buffer. It requires:
* 1) There is only one block who reads/writes the target buffer
* 2) There is only one buffer load/store of this buffer in the block
* \param block_rv The block operates on the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \return The reindex stage block.
*/
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
73 changes: 73 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
self, block, write_buffer_index, storage_scope
)

@type_checked
def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV:
"""Create a block that read/write a buffer region into a read/write cache with reindexing.
The layout of the cache will be the same as by the iterators of the block that reads/writes
the buffer. It requires:
1) There is only one block who reads/writes the target buffer
2) There is only one buffer load/store of this buffer in the block

Parameters
----------
block: BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_index_type : str
Type of the buffer index, "read" or "write"

Returns
-------
reindex_block : BlockRV
The block of the reindex stage

Examples
--------

Before transform_layout, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_reindex(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vj, vi] * 2.0

Create the schedule and do transform_layout:

.. code-block:: python

sch = tir.Schedule(before_reindex)
block = sch.get_block("B")
sch.reindex(block, 0, "read)

After applying reindex, the IR becomes:

.. code-block:: python

@T.prim_func
def after_reindex(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
A_reindex = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("A_reindex"):
vi, vj = T.axis.remap("SS", [i, j])
A_reindex[vi, vj] = A[vj, vi]
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A_reindex[vi, vj] * 2.0

"""
assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_index_type_enum
)

########## Schedule: Compute location ##########

@type_checked
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type);
TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const String& storage_scope) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
15 changes: 15 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/*!
*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
* buffer. It requires:
* 1) There is only one block who reads/writes the target buffer
* 2) There is only one buffer load/store of this buffer in the block
* \param self The state of the schedule
* \param block_rv The block operates on the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \return The reindex stage block.
*/
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type);
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
Loading