[TIR][Transform] Implement InlinePrivateFunctions (#16184)
* [TIR] Update DeclBuffer nodes when specializing PrimFunc

Prior to this commit, a buffer whose parameters (e.g. shape/stride)
contained a specialized parameter would not be updated when appearing
in a `DeclBuffer` node.  This commit updates the `Specialize` function
to update buffers that occur in `DeclBuffer` nodes.

* [TIR] Handle specialization that remaps a buffer var

* [TIR] Handle specialization of buffer variable to PrimExpr

* [TIR][Transform] Implement InlinePrivateFunctions

The functionality to express a call from one `PrimFunc` to another was
introduced in #14889.  While this
was initially planned to be supported at codegen for all targets (see
#15835), this resulted in breakage
on some backends (see #16033).
After discussion, the plan was changed to support TIR inlining, which
would enable the same high-level functionality in TIR without
requiring immediate low-level support across all codegens.

This commit implements and tests a new IRModule transform
`InlinePrivateFunctions`, which can be used as part of lowering in a
follow-up commit.

Because this is initially implemented for use quite late in the
lowering flow, many constructs are not currently supported.  The
current implementation has the following restrictions.

* `tir::Block` nodes may not occur in the inlined function.  Because a
  subroutine may be called multiple times, inlining of a subroutine
  that contains `tir::Block` would result in non-unique names.
  Support of subroutines with `tir::Block` instances will require
  de-duplication of block names.

* The subroutine's callsite must occur within a `tir::Evaluate` node.
  Because inlining a subroutine inserts the `tir::Stmt` body at the
  point of use, replacement must occur in a context where a
  `tir::Stmt` can be returned.  Support of subroutines that are called
  within an expression (e.g. replacing `func` in `Buf[0] = func(1) +
  func(2)`) would require hoisting preprocessing done in the
  subroutine to the parent `tir::Stmt`.

* The subroutine may only accept primitive arguments, and must have an
  empty `buffer_map`.  Support of subroutines that are called with
  `tir::Buffer` or `tir::BufferRegion` arguments would require a way
  to represent these arguments at the callsite, and substitution of
  the buffer into the callee.

If these unsupported constructs are used, then the inlining of those
functions is skipped.  This commit includes unit tests for these
unsupported constructs, to validate that `InlinePrivateFunctions`
produces well-formed output even when they are present.

* Updates based on review comments

* ci bump

* CI bump
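
The three restrictions listed in the commit message can be illustrated with a small toy model. This is a sketch only: the classes `Call`, `Evaluate`, `Store`, `Block`, `Func` and the helper `inline_private_functions` are hypothetical stand-ins written for this illustration, not TVM's API and not the implementation added in this commit. It assumes a function is eligible when it is private, has an empty `buffer_map`, contains no block, and is called from a standalone evaluate statement; anything else is left untouched.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union


# Toy IR. Every name below is hypothetical and only mimics the TIR concepts.
@dataclass
class Call:
    callee: str
    args: List[int]  # primitive arguments only


@dataclass
class Evaluate:
    value: Union[Call, int]  # statement wrapping a single expression


@dataclass
class Store:
    target: str
    value: Union[Call, int, str]  # a Call here sits inside an expression context


@dataclass
class Block:
    name: str
    body: List["Stmt"]


Stmt = Union[Evaluate, Store, Block]


@dataclass
class Func:
    params: List[str]
    body: List[Stmt]
    private: bool = True
    buffer_map: Dict[str, str] = field(default_factory=dict)


def can_inline(func: Func) -> bool:
    # Only Block can nest statements in this toy IR, so a top-level scan suffices.
    has_block = any(isinstance(s, Block) for s in func.body)
    return func.private and not func.buffer_map and not has_block


def substitute(stmt: Stmt, env: Dict[str, int]) -> Stmt:
    # Replace a parameter name with the argument passed at the callsite.
    if isinstance(stmt, Store) and isinstance(stmt.value, str):
        return Store(stmt.target, env.get(stmt.value, stmt.value))
    return stmt


def inline_private_functions(module: Dict[str, Func], caller: str) -> List[Stmt]:
    out: List[Stmt] = []
    for stmt in module[caller].body:
        # Only a call that is the whole Evaluate statement can be replaced by statements.
        call = stmt.value if isinstance(stmt, Evaluate) else None
        callee = module.get(call.callee) if isinstance(call, Call) else None
        if callee is not None and can_inline(callee):
            env = dict(zip(callee.params, call.args))
            out.extend(substitute(s, env) for s in callee.body)
        else:
            out.append(stmt)  # unsupported construct: skip inlining, keep the call
    return out


module = {
    "main": Func(
        params=[],
        body=[
            Evaluate(Call("helper", [7])),      # eligible callsite
            Store("B", Call("helper", [1])),    # call inside an expression: skipped
            Evaluate(Call("blocky", [2])),      # callee contains a block: skipped
        ],
        private=False,
    ),
    "helper": Func(params=["x"], body=[Store("A", "x")]),
    "blocky": Func(params=["x"], body=[Block("b", [Store("A", "x")])]),
}
result = inline_private_functions(module, "main")
```

Only the first call is replaced by the callee's body; the other two statements are returned unchanged, which mirrors the "inlining is skipped" behaviour described above.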
Lunderberg committed Jan 3, 2024
1 parent 97f6e65 commit 8eec0bf
Showing 6 changed files with 764 additions and 51 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
@@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize();
  */
 TVM_DLL Pass FP8StorageLegalize();

+/*!
+ * \brief Inline calls to private functions
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InlinePrivateFunctions();
+
 /*!
  * \brief Rewrite the pointer content type of arguments,
  * as well as Alloc internal to the function to use
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -230,6 +230,17 @@ def StorageRewrite():
     return _ffi_api.StorageRewrite()  # type: ignore


+def InlinePrivateFunctions():
+    """Inline calls to private functions
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InlinePrivateFunctions()  # type: ignore
+
+
 def PointerValueTypeRewrite():
     """
     Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
106 changes: 86 additions & 20 deletions src/tir/ir/specialize.cc
@@ -29,6 +29,7 @@

 #include <functional>

+#include "../transforms/ir_utils.h"
 #include "functor_common.h"

 namespace tvm {
@@ -115,18 +116,18 @@ class PrimFuncSpecializer : public StmtExprMutator {
  private:
   Stmt VisitStmt_(const BlockNode* op) final {
     // Step.0. Define buffer mappings which is allocated inside the block
-    Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
-        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<Buffer> alloc_buffers =
+        op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); });

     // Step.1. Recursively visit block body
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<BlockNode>();
     ICHECK(op != nullptr);

-    Array<BufferRegion> reads = op->reads.Map(
-        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
-    Array<BufferRegion> writes = op->writes.Map(
-        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> reads =
+        op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); });
+    Array<BufferRegion> writes =
+        op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); });

     if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
         writes.same_as(op->writes)) {
@@ -140,16 +141,54 @@ class PrimFuncSpecializer : public StmtExprMutator {
     }
   }

+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    // Visit the buffer before delegating to StmtExprMutator, so the
+    // buffer's replacement will be defined before the point of use.
+    Var old_buffer_var = op->buffer->data;
+    Buffer new_buf = MutateAllocBuffer(op->buffer);
+
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+    if (!new_buf.same_as(node->buffer)) {
+      node.CopyOnWrite()->buffer = new_buf;
+    }
+
+    // If the buffer variable is being remapped to an expression, we
+    // still need a tir::Var to be used as a the buffer variable.
+    // Therefore, generate a LetStmt that will provide a tir::Var for
+    // the buffer to use.
+    //
+    // This step is only required when a buffer definition is using a
+    // previously-defined buffer variable, which is therefore eligible
+    // for specialization. An allocation in the
+    // `BlockNode::alloc_buffers` defines both the buffer variable and
+    // the buffer, this check is unnecessary there. In addition, if
+    // the buffer var has been remapped to another variable, it has already
+    // been handled as part of the buffer mutation.
+    Var new_buffer_var = node->buffer->data;
+    Stmt stmt = std::move(node);
+
+    if (new_buffer_var.same_as(old_buffer_var)) {
+      auto remapped_data = VisitExpr(old_buffer_var);
+      if (!remapped_data.same_as(old_buffer_var)) {
+        stmt = LetStmt(old_buffer_var, remapped_data, stmt);
+      }
+    }
+
+    return stmt;
+  }
+
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<BufferStoreNode>();
     ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
+
+    auto new_buf = GetNewBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return GetRef<BufferStore>(op);
     } else {
       auto n = CopyOnWrite(op);
-      n->buffer = it->second;
+      n->buffer = new_buf;
       return Stmt(n);
     }
   }
@@ -158,12 +197,13 @@ class PrimFuncSpecializer : public StmtExprMutator {
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<BufferLoadNode>();
     ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
+
+    auto new_buf = GetNewBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return GetRef<BufferLoad>(op);
     } else {
       auto n = make_object<BufferLoadNode>(*op);
-      n->buffer = it->second;
+      n->buffer = new_buf;
       return PrimExpr(n);
     }
   }
@@ -198,17 +238,23 @@ class PrimFuncSpecializer : public StmtExprMutator {

  private:
   Buffer MutateBuffer(const Buffer& buffer) {
+    // For the data variable, only Var-to-Var remapping can be handled
+    // in MutateBuffer.  See the DeclBuffer visitor for the handling
+    // of Var-to-PrimExpr remapping.
+    Var data = VisitExpr(buffer->data).as<Var>().value_or(buffer->data);
+
     Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
     Array<PrimExpr> strides =
         buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });

     PrimExpr elem_offset = VisitExpr(buffer->elem_offset);

-    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
-        buffer->strides.same_as(strides)) {
+    if (buffer->data.same_as(data) && buffer->elem_offset.same_as(elem_offset) &&
+        buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) {
       return buffer;
     } else {
       auto n = make_object<BufferNode>(*buffer.get());
+      n->data = std::move(data);
       n->elem_offset = std::move(elem_offset);
       n->shape = std::move(shape);
       n->strides = std::move(strides);
@@ -227,14 +273,33 @@ class PrimFuncSpecializer : public StmtExprMutator {
   }

   Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    ICHECK(!buffer_map_.count(alloc_buf))
+        << "Multiple points of definition found for buffer " << alloc_buf;
+
     Buffer buf = MutateBuffer(alloc_buf);
-    if (buf.same_as(alloc_buf)) {
-      return alloc_buf;
-    } else {
-      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
-      buffer_map_[alloc_buf] = buf;
-      return buf;
+    buffer_map_[alloc_buf] = buf;
+    return buf;
+  }
+
+  Buffer GetNewBuffer(const Buffer& old_buffer) {
+    if (auto it = buffer_map_.find(old_buffer); it != buffer_map_.end()) {
+      return it->second;
     }
+
+    auto mutated = MutateBuffer(old_buffer);
+    ICHECK(mutated.same_as(old_buffer))
+        << "Buffer " << old_buffer << " (shape = " << old_buffer->shape << ")"
+        << " was used without a declaration, "
+        << "and would be specialized into " << mutated << " (shape = " << mutated->shape << ").  "
+        << "While usage of an undeclared buffer is currently allowed in TIR, "
+        << "mutation must occur at the buffer's point of definition "
+        << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details).  "
+        << "Please add a definition for this buffer, "
+        << "either in the PrimFunc's buffer_map, "
+        << "in a tir::Block's alloc_buffer, "
+        << "or in a DeclBuffer statement.";
+
+    return old_buffer;
   }

   BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
@@ -311,6 +376,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
       << " vs. " << specific_buf->strides.size() << ".";

   // Updating var mapping using specific_expr
+  build_var_mapping(specific_buf->data, buf_to_specialize->data);
   for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
     build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
   }
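
The data-variable handling in the `DeclBuffer` visitor and `MutateBuffer` can be hard to follow from the diff alone, so here is a small toy model of the rule. It is a sketch under stated assumptions, not TVM code: `Var`, `Add`, `Buffer`, `DeclBuffer`, `Let`, `subst` and `specialize_decl` are hypothetical names invented for this illustration. The assumed rule is that a Var-to-Var remap rewrites the buffer directly, while a Var-to-expression remap keeps the original variable on the buffer and wraps the declaration in a let-binding that defines that variable.

```python
from dataclasses import dataclass
from typing import Union


# Toy expressions and statements. Hypothetical stand-ins for the TIR nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Add, int]


@dataclass(frozen=True)
class Buffer:
    data: Var
    shape: tuple


@dataclass(frozen=True)
class DeclBuffer:
    buffer: Buffer
    body: str


@dataclass(frozen=True)
class Let:
    var: Var
    value: Expr
    body: object


def subst(expr, var_map):
    # Apply the specialization map to an expression.
    if isinstance(expr, Var):
        return var_map.get(expr, expr)
    if isinstance(expr, Add):
        return Add(subst(expr.lhs, var_map), subst(expr.rhs, var_map))
    return expr


def specialize_decl(decl, var_map):
    old = decl.buffer
    mapped = subst(old.data, var_map)
    # A buffer needs a variable as its data handle, so only Var-to-Var applies here.
    new_data = mapped if isinstance(mapped, Var) else old.data
    new_buf = Buffer(new_data, tuple(subst(e, var_map) for e in old.shape))
    stmt = DeclBuffer(new_buf, decl.body)
    # Var-to-expression remap: keep the old variable and bind it to the expression.
    if new_data == old.data and mapped != old.data:
        return Let(old.data, mapped, stmt)
    return stmt


n, ptr = Var("n"), Var("ptr")
decl = DeclBuffer(Buffer(ptr, (n,)), "body")

shape_only = specialize_decl(decl, {n: 16})
var_to_var = specialize_decl(decl, {ptr: Var("q")})
var_to_expr = specialize_decl(decl, {ptr: Add(Var("base"), 4)})
```

In the third case the buffer still refers to `ptr`, and the surrounding `Let` supplies its value, which corresponds to the `LetStmt` wrapping described in the visitor's comment.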