
[TIR][Schedule] Transform layout #10538

Merged (11 commits, Mar 23, 2022)

Changes from 10 commits
16 changes: 16 additions & 0 deletions include/tvm/tir/index_map.h
@@ -106,18 +106,34 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;

/*!
* \brief Convert to string representation in Python.
* \return The stringified lambda expression in Python.
*/
String ToPythonString() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
}

static constexpr const char* _type_key = "tir.IndexMap";

TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
};

class IndexMap : public ObjectRef {
public:
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);

/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
* \return The created index map
*/
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);

/*! \brief Generate the inverse mapping.
*
* The range of the input indices is required in order to ensure
24 changes: 24 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

@@ -36,6 +37,14 @@ enum class ScheduleErrorRenderLevel : int32_t {
kNone = 2,
};

/*! \brief Type of buffer index */
enum class BufferIndexType : int32_t {
/*! \brief Index of a read buffer */
kRead = 0,
/*! \brief Index of a written buffer */
kWrite = 1,
};

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
@@ -521,6 +530,21 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

/******** Schedule: Layout transformation ********/
/*!
* \brief Apply a transformation represented by IndexMap to a buffer
* \details The indices and the access region of the target buffer are transformed by the given
* index_map. The index_map is used to infer the new shape of the buffer. The buffer must be either
* a function parameter, or allocated in a block (it cannot be a buffer subregion created via
* 'match_buffer').
* \param block_rv The block that accesses the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
@@ -57,7 +57,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError, BufferType

from . import schedule
from . import ir_builder
58 changes: 57 additions & 1 deletion python/tvm/tir/function.py
@@ -16,7 +16,8 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union
from typing import Callable, List, Mapping, Optional, Union
import inspect

import tvm._ffi
import tvm.runtime
@@ -239,3 +240,58 @@ def get(name: str):
The TensorIntrin with the specified name.
"""
return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore


@tvm._ffi.register_object("tir.IndexMap")
class IndexMap(Object):
"""A mapping from multi-dimensional indices to another set of multi-dimensional indices

Parameters
----------
initial_indices : List[Var]
Variables representing the indices prior to remapping.
final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
"""

initial_indices: List[Var]
final_indices: List[PrimExpr]

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)

@staticmethod
def from_func(mapping_function: Callable, ndim: Optional[int] = None):
"""Create an index map from a function

Parameters
----------
mapping_function : Callable
The function to map from source indices to target indices
ndim : Optional[int]
The number of dimensions of the buffer being mapped. Required when
mapping_function accepts variadic positional arguments (*args);
otherwise inferred from the function signature.
"""
params = inspect.signature(mapping_function).parameters
default_index_dtype = "int32"
args = []
var_arg_name = None
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name
else:
raise ValueError("transform_layout mapping may not have keyword-only arguments or **kwargs")

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
assert ndim is not None, "ndim must be specified when *args is used"
num_var_args = ndim - len(args)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))

final_indices = mapping_function(*args)
return IndexMap(args, final_indices)
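
For context, a minimal usage sketch of from_func follows (not part of the diff; the lambdas and variable names are illustrative):

# Minimal sketch, assuming this diff is applied; lambdas are illustrative.
from tvm.tir.function import IndexMap

# One Var is created per named parameter of the mapping function.
tiled = IndexMap.from_func(lambda i, j: (i // 16, j // 16, i % 16, j % 16))

# With a variadic parameter, ndim is required; the trailing dimensions
# become Vars named after the *args parameter ("dims_0", "dims_1", ...).
batched = IndexMap.from_func(lambda n, *dims: (n, *dims), ndim=4)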
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
@@ -19,6 +19,6 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError, BufferType
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace
87 changes: 86 additions & 1 deletion python/tvm/tir/schedule/schedule.py
@@ -15,18 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
from ..function import IndexMap

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
from .trace import Trace
from ._type_checker import type_checked
import enum


@register_error
Expand Down Expand Up @@ -71,6 +73,13 @@ def __init__(self) -> None:
}


class BufferType(enum.IntEnum):
"""Type of buffer in access regions of a block"""

READ = 0
WRITE = 1

Comment on lines +76 to +81

Member:
@vinx13 Sorry I'm late and missed the code review. On this particular change, I'm 100% in favor of having an enum type on the C++ side. However, it may be better to just use strings on the Python side, i.e. "read" and "write", to indicate BufferType. If you agree with my opinion, would you mind sending a quick patch? Thanks a lot!

Contributor:
This sounds good to me as well, and it would give the same readability benefits at the call site. I kind of like the enum usage in Python too, since it lets the reader know what options are possible, but that's the less important benefit.

At some point, it would be nice to have a macro to define an enum in C++, along with its value/name mapping and an FFI interface, so that there would be a clear way to handle these.


def _parse_error_render_level(error_render_level: str) -> int:
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
@@ -2111,6 +2120,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)

########## Schedule: Layout transformation ##########

@type_checked
def transform_layout(
self,
block: BlockRV,
buffer_index: int,
buffer_type: BufferType,
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to buffer
Parameters
----------
block_rv : BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_type : BufferType
Type of the buffer, READ or WRITE.
index_map : Union[IndexMap, Callable]
The transformation to apply

Examples
--------
Before transform_layout, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_transform_layout(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128), "float32")
B = T.alloc_buffer((128, 128), "float32")
C = T.match_buffer(c, (128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do transform_layout:

.. code-block:: python

sch = tir.Schedule(before_transform_layout)
sch.transform_layout(sch.get_block("B"), buffer_index=0, buffer_type=BufferType.WRITE,
index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
print(sch.mod["main"].script())

After applying transform_layout, the IR becomes:

.. code-block:: python

@T.prim_func
def after_transform_layout(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128), "float32")
B = T.alloc_buffer((8, 8, 16, 16), "float32")
C = T.match_buffer(c, (128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0

"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_type, index_map
)

########## Schedule: Misc ##########

@type_checked
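
As a hedged sketch of the new Python API (mirroring the docstring example above; the prim_func and variable names are illustrative), passing a callable for index_map should be equivalent to constructing the IndexMap explicitly, since the callable form is converted via IndexMap.from_func:

# Sketch only: the prim_func mirrors the docstring example; equivalence of
# the two calls follows from the callable being converted by from_func.
from tvm import tir
from tvm.script import tir as T
from tvm.tir.function import IndexMap
from tvm.tir.schedule import BufferType


@T.prim_func
def elemwise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


# Callable form: converted internally with IndexMap.from_func.
sch = tir.Schedule(elemwise)
sch.transform_layout(sch.get_block("B"), buffer_index=0,
                     buffer_type=BufferType.WRITE,
                     index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))

# Explicit form: should record the same transformation.
sch2 = tir.Schedule(elemwise)
sch2.transform_layout(sch2.get_block("B"), buffer_index=0,
                      buffer_type=BufferType.WRITE,
                      index_map=IndexMap.from_func(
                          lambda m, n: (m // 16, n // 16, m % 16, n % 16)))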
51 changes: 50 additions & 1 deletion src/tir/ir/index_map.cc
@@ -27,6 +27,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <sstream>

@@ -40,6 +41,15 @@ IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
data_ = std::move(n);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
Array<Var> initial_indices;
initial_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
}
return IndexMap(initial_indices, func(initial_indices));
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
@@ -142,13 +152,52 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
return output;
}

String IndexMapNode::ToPythonString() const {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
if (used_names.count(initial_index->name_hint)) {
std::string new_name = initial_index->name_hint + std::to_string(used_names.size());
used_names.insert(new_name);
var_remap.Set(initial_index, Var(new_name));
} else {
used_names.insert(initial_index->name_hint);
}
}
std::ostringstream oss;
oss << "lambda ";
for (size_t i = 0; i < initial_indices.size(); ++i) {
if (i != 0) {
oss << ", ";
}
auto it = var_remap.find(initial_indices[i]);
if (it != var_remap.end()) {
oss << (*it).second;
} else {
oss << initial_indices[i];
}
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
oss << Substitute(final_indices[i], var_remap);
oss << ", ";
}
oss << ")";
return String(oss.str());
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IndexMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IndexMapNode*>(node.get());
p->stream << "index_map(" << op->initial_indices << ", " << op->final_indices << ")";
p->stream << "index_map(" << op->ToPythonString() << ")";
});

TVM_REGISTER_NODE_TYPE(IndexMapNode);

TVM_REGISTER_GLOBAL("tir.IndexMap")
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
return IndexMap(initial_indices, final_indices);
});

} // namespace tir
} // namespace tvm
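
For illustration, a hedged sketch of what the new repr might print; the exact text below is an assumption inferred from ToPythonString and TVM's default PrimExpr printing:

# Sketch only: the printed format is an assumption, not from the diff.
from tvm.tir.function import IndexMap

m = IndexMap.from_func(lambda i, j: (i // 16, i % 16))
print(m)
# Expected, approximately:
# index_map(lambda i, j: (floordiv(i, 16), floormod(i, 16), ))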
10 changes: 10 additions & 0 deletions src/tir/schedule/analysis.h
@@ -401,6 +401,16 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/*!
* \brief Find the defining site of the buffer in the given block and its ancestors
* \param block_sref The block sref
* \param buffer The buffer
* \return The defining site of the buffer and whether the buffer is allocated (otherwise the
* buffer is from match_buffer).
*/
std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
const Buffer& buffer);

/******** Reduction Block Related ********/

/*!