Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Schedule] Transform layout #10538

Merged
merged 11 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,34 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;

/*!
* \brief Convert to string representation in Python.
* \return The stringified lambda expression in Python.
*/
String ToPythonString() const;

/*!
* \brief Visit the fields of this node.
* \param v The visitor; it is handed `initial_indices` and `final_indices` under the
* keys "initial_indices" and "final_indices" respectively.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
}

static constexpr const char* _type_key = "tir.IndexMap";

TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
};

class IndexMap : public ObjectRef {
public:
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);

/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
* \return The created index map
*/
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);

/*! \brief Generate the inverse mapping.
*
* The range of the input indices is required in order to ensure
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

Expand Down Expand Up @@ -521,6 +522,21 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

/******** Schedule: Layout transformation ********/
/*!
* \brief Apply a transformation represented by IndexMap to buffer
* \details The indices and the access region to the target buffer are transformed by the given
* index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either
* a function parameter, or allocated in a block (it cannot be a buffer subregion created via
* 'match_buffer').
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
* \param block_rv The block that accesses the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param is_write_index Whether the buffer_index is the index of the block's write region.
* \param index_map The transformation to apply.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index,
const IndexMap& index_map) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
40 changes: 39 additions & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union
from typing import Callable, List, Mapping, Union
import inspect

import tvm._ffi
import tvm.runtime
Expand Down Expand Up @@ -239,3 +240,40 @@ def get(name: str):
The TensorIntrin with the specified name.
"""
return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore


@tvm._ffi.register_object("tir.IndexMap")
class IndexMap(Object):
    """Remapping from one set of multi-dimensional indices to another.

    Parameters
    ----------
    initial_indices : List[Var]
        Variables standing for the indices before the remapping is applied.
    final_indices : List[PrimExpr]
        Expressions giving the indices after the remapping is applied.
    """

    initial_indices: List[Var]
    final_indices: List[PrimExpr]

    @staticmethod
    def from_func(func: Callable) -> "IndexMap":
        """Build an index map out of a Python callable.

        Parameters
        ----------
        func : Callable
            Maps source indices to target indices. The number of dimensions
            is taken from the number of parameters in its signature. It may
            return a list, a tuple, or a single value; a single value is
            wrapped into a one-element list.

        Returns
        -------
        index_map : IndexMap
            The object returned by the ``IndexMapFromFunc`` FFI call.
        """

        def wrap(args: List[Var]) -> List[PrimExpr]:
            # Normalise whatever `func` produced into a list for the FFI side.
            mapped = func(*args)
            if isinstance(mapped, list):
                return mapped
            if isinstance(mapped, tuple):
                return list(mapped)
            return [mapped]

        arity = len(inspect.signature(func).parameters)
        return _ffi_api.IndexMapFromFunc(arity, wrap)  # type: ignore # pylint: disable=no-member
79 changes: 78 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
from ..function import IndexMap

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -2111,6 +2112,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)

########## Schedule: Layout transformation ##########

@type_checked
def transform_layout(
    self,
    block: BlockRV,
    buffer_index: int,
    is_write_index: bool,
    index_map: Union[IndexMap, Callable],
) -> None:
    """Apply a transformation represented by IndexMap to buffer

    Parameters
    ----------
    block : BlockRV
        The block that accesses the target buffer
    buffer_index : int
        The index of the buffer in block's read or write region
    is_write_index : bool
        Whether the buffer_index is the index of the block's write region
    index_map : Union[IndexMap, Callable]
        The transformation to apply. A callable is first converted with
        ``IndexMap.from_func``.

    Examples
    --------
    Before transform_layout, in TensorIR, the IR is:

    .. code-block:: python

        @T.prim_func
        def before_transform_layout(a: T.handle, c: T.handle) -> None:
            A = T.match_buffer(a, (128, 128), "float32")
            B = T.alloc_buffer((128, 128), "float32")
            C = T.match_buffer(c, (128, 128), "float32")
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0
            for i, j in T.grid(128, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = B[vi, vj] + 1.0

    Create the schedule and do transform_layout:

    .. code-block:: python

        sch = tir.Schedule(before_transform_layout)
        sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True,
                             index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
        print(sch.mod["main"].script())

    After applying transform_layout, the IR becomes:

    .. code-block:: python

        @T.prim_func
        def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
            A = T.match_buffer(a, (128, 128), "float32")
            B = T.alloc_buffer((8, 8, 16, 16), "float32")
            C = T.match_buffer(c, (128, 128), "float32")
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
            for i, j in T.grid(128, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0

    """
    # Plain callables are turned into an IndexMap before crossing the FFI boundary.
    resolved_map = IndexMap.from_func(index_map) if callable(index_map) else index_map
    _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
        self, block, buffer_index, is_write_index, resolved_map
    )

########## Schedule: Misc ##########

@type_checked
Expand Down
51 changes: 51 additions & 0 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <sstream>

Expand All @@ -40,6 +41,15 @@ IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
data_ = std::move(n);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
  // Create one int32 placeholder variable per input dimension, named i0, i1, ...
  Array<Var> placeholders;
  placeholders.reserve(ndim);
  int dim = 0;
  while (dim < ndim) {
    placeholders.push_back(Var("i" + std::to_string(dim), DataType::Int(32)));
    ++dim;
  }
  // Evaluating `func` on the placeholders yields the expressions for the final indices.
  return IndexMap(placeholders, func(placeholders));
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
Expand Down Expand Up @@ -142,6 +152,40 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
return output;
}

String IndexMapNode::ToPythonString() const {
  // Names already taken by a lambda parameter. Every initial index must end up with a
  // distinct name, otherwise the printed lambda would declare the same parameter twice.
  std::unordered_set<std::string> used_names;
  // Initial indices whose name_hint clashed, mapped to a replacement variable.
  Map<Var, PrimExpr> var_remap;
  for (const Var& initial_index : initial_indices) {
    if (used_names.count(initial_index->name_hint)) {
      const std::string base = initial_index->name_hint;
      // Start from the same suffix as before (the number of names in use), but keep
      // probing until the candidate is really unused: e.g. for indices named
      // (i2, i, i) the first candidate "i2" is already taken.
      size_t suffix = used_names.size();
      std::string new_name = base + std::to_string(suffix);
      while (used_names.count(new_name)) {
        ++suffix;
        new_name = base + std::to_string(suffix);
      }
      used_names.insert(new_name);
      var_remap.Set(initial_index, Var(new_name));
    } else {
      used_names.insert(initial_index->name_hint);
    }
  }

  std::ostringstream oss;
  oss << "lambda ";
  for (size_t i = 0; i < initial_indices.size(); ++i) {
    if (i != 0) {
      oss << ", ";
    }
    // Print the replacement variable if this index had to be renamed.
    auto it = var_remap.find(initial_indices[i]);
    if (it != var_remap.end()) {
      oss << (*it).second;
    } else {
      oss << initial_indices[i];
    }
  }
  oss << ": (";
  // Each element is followed by ", ", so a single final index is still printed with a
  // trailing comma inside the parentheses.
  for (size_t i = 0; i < final_indices.size(); ++i) {
    oss << Substitute(final_indices[i], var_remap);
    oss << ", ";
  }
  oss << ")";
  return String(oss.str());
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IndexMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IndexMapNode*>(node.get());
Expand All @@ -150,5 +194,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_NODE_TYPE(IndexMapNode);

// Register the two-argument IndexMap constructor under the global name "tir.IndexMap".
TVM_REGISTER_GLOBAL("tir.IndexMap")
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
return IndexMap(initial_indices, final_indices);
});

// Register IndexMap::FromFunc under the global name "tir.IndexMapFromFunc".
TVM_REGISTER_GLOBAL("tir.IndexMapFromFunc").set_body_typed(IndexMap::FromFunc);

} // namespace tir
} // namespace tvm
10 changes: 10 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/*!
* \brief Find the defining site of the buffer in the given block and its ancestors
* \param block_sref The block sref
* \param buffer The buffer
* \return The defining site of the buffer and whether the buffer is allocated (otherwise the
* buffer is from match_buffer).
*/
std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
const Buffer& buffer);

/******** Reduction Block Related ********/

/*!
Expand Down
31 changes: 31 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,37 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
return access_region[n]->buffer;
}

/*!
 * \brief Find the block that defines `buffer`, searching `block_sref` and its ancestors.
 * \param block_sref The sref to start the upward search from.
 * \param buffer The buffer whose defining site is wanted.
 * \return A pair of (defining block sref, is_alloc). `is_alloc` is true when the buffer is
 * found in a block's `alloc_buffers`, and false when it is found in a block's
 * `match_buffers`. When no block on the path defines the buffer, the result is
 * {NullOpt, false}.
 */
std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
                                                          const Buffer& buffer) {
  // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or
  // match_buffers.
  for (const StmtSRefNode* defining_site_sref = block_sref.get(); defining_site_sref != nullptr;
       defining_site_sref = defining_site_sref->parent) {
    const auto* block = defining_site_sref->StmtAs<BlockNode>();
    // If this sref is not a block sref, skip it.
    if (block == nullptr) {
      continue;
    }
    // Try to find the buffer in `alloc_buffers`
    for (const Buffer& alloc_buffer : block->alloc_buffers) {
      if (buffer.same_as(alloc_buffer)) {
        return {GetRef<StmtSRef>(defining_site_sref), true};
      }
    }
    // We do not allow the buffer being defined in `match_buffer`. Compare against the buffer
    // held by the region: comparing against the MatchBufferRegion object itself can never
    // match a Buffer.
    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
      if (buffer.same_as(match_buffer->buffer)) {
        return {GetRef<StmtSRef>(defining_site_sref), false};
      }
    }
  }
  // If we cannot find the defining site block, it means that the buffer must be in the function's
  // buffer_map, which isn't an intermediate buffer.
  return {NullOpt, false};
}

/******** Pattern Matcher ********/

/*!
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
}

/******** Schedule: Layout transformation ********/
// Resolves `block_rv` to its sref, applies tir::TransformLayout to the schedule state with the
// given buffer index / write flag / index map, and then runs the state's DebugVerify().
// NOTE(review): the TVM_TIR_SCHEDULE_BEGIN/END macros are defined outside this view; they
// presumably wrap the body in error handling reported as "transform_layout" -- confirm.
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
bool is_write_index, const IndexMap& index_map) {
TVM_TIR_SCHEDULE_BEGIN();
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, is_write_index, index_map);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ class ConcreteScheduleNode : public ScheduleNode {
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;

/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index,
const IndexMap& index_map) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down
16 changes: 16 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,22 @@ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& an
*/
TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key);

/******** Schedule: Layout transformation ********/
/*!
* \brief Apply a transformation represented by IndexMap to buffer
* \details The indices and the access region to the target buffer are transformed by the given
* index_map. The index_map is also used to infer the new shape of the buffer. Buffer must be
* one of the parameters of the function, or allocated in some block (it cannot be a buffer
* subregion created via match_buffer).
* \param self The state of the schedule
* \param block_sref The block sref that accesses the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param is_write_index Whether the buffer_index is the index of the block's write region.
* \param index_map The transformation to apply.
*/
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
bool is_write_index, const IndexMap& index_map);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading