[TIR][Schedule] Enhance compute-inline for fusion (#15142)
This PR enhances `compute-inline` and `reverse-compute-inline` to handle
more complex fusion patterns, where the buffer accesses in the inlined
assignment may use generic index expressions rather than plain index
variables. This is useful in multiple cases in LLM inference.
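
For illustration (not part of this commit's diff), here is a minimal sketch of the kind of producer the enhanced compute-inline is intended to accept: the store indices form a bijective affine map of the block iterators instead of plain, distinct variables. The buffer shapes, block names, and the `before` function are hypothetical.

import tvm
from tvm.script import tir as T

@T.prim_func
def before(a: T.handle, c: T.handle):
    A = T.match_buffer(a, (128, 128), "float32")
    C = T.match_buffer(c, (16384,), "float32")
    B = T.alloc_buffer((16384,), "float32")
    for i, j in T.grid(128, 128):
        with T.block("producer"):
            vi, vj = T.axis.remap("SS", [i, j])
            # The store index is a bijective affine expression of the block
            # iters, not a plain index variable.
            B[vi * 128 + vj] = A[vi, vj] * T.float32(2)
    for k in range(16384):
        with T.block("consumer"):
            vk = T.axis.spatial(16384, k)
            C[vk] = B[vk] + T.float32(1)

sch = tvm.tir.Schedule(before)
sch.compute_inline(sch.get_block("producer"))
sch.mod.show(black_format=False)

If the new bijective-affine detection applies as described, the consumer body becomes roughly C[vk] = A[vk // 128, vk % 128] * 2.0 + 1.0 after inlining, whereas the previous distinct-atomic-variable check would have rejected this producer.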
junrushao committed Jun 24, 2023
1 parent 512d35a commit 5a3523d
Showing 7 changed files with 510 additions and 183 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
@@ -127,7 +127,7 @@ def check_sketches(
def print_sketches(sketches: List[Schedule]):
for i, sch in enumerate(sketches):
print(f"###### {i}")
sch.mod.show()
sch.mod.show(black_format=False)
for inst in sch.trace.insts:
if inst in sch.trace.decisions:
print(f'("{inst.kind.name}", {sch.trace.decisions[inst]}),')
262 changes: 149 additions & 113 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -22,9 +22,8 @@ namespace tvm {
namespace tir {

static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
'A[i, j, k, ...] = f(i, j, k, ...)',
where the indices on the left are distinct atomic variables,
and there should be no variables other than the index variables)";
'A[f(i, j, k, ...)] = g(i, j, k, ...)',
where the store index mapping f on the left-hand side is bijective affine.)";

static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
@@ -172,20 +171,53 @@ class NonSingleProducerError : public ScheduleError {
*/
static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref);
StmtSRef producer_block_sref{nullptr};
if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) {
producer_block_sref = producers[0]->src;
if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) {
Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref);
if (consumers.size() == 1) {
return producer_block_sref;
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref);
const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead(
self, GetRef<Block>(consumer_block), scope_root_sref);
class ProducerFinder : public StmtVisitor {
public:
static std::vector<Block> GetProducer(const Buffer& buffer, const Block& scope_block) {
ProducerFinder finder(buffer);
finder(scope_block);
return finder.producer_across_scope_.back();
}

private:
explicit ProducerFinder(const Buffer& buffer) : buffer_(buffer) {
producer_across_scope_.push_back({});
}

void VisitStmt_(const BlockNode* node) final {
producer_across_scope_.push_back({});
StmtVisitor::VisitStmt_(node);
// not a leaf block
if (!producer_across_scope_.back().empty()) {
auto producer_under_block = producer_across_scope_.back();
producer_across_scope_.pop_back();
producer_across_scope_.back().insert(producer_across_scope_.back().end(),
producer_under_block.begin(),
producer_under_block.end());
return;
}
// leaf block
producer_across_scope_.pop_back();
for (const auto& write : node->writes) {
if (write->buffer.same_as(buffer_)) {
producer_across_scope_.back().push_back(GetRef<Block>(node));
break;
}
}
}
Buffer buffer_;
std::vector<std::vector<Block>> producer_across_scope_;
};
std::vector<Block> producer_across_scope =
ProducerFinder::GetProducer(consumer_buffer, GetRef<Block>(scope_block));
if (producer_across_scope.size() != 1) {
throw NonSingleProducerError(self->mod, GetRef<Block>(consumer_block));
}
const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref);
throw NonSingleProducerError(self->mod, GetRef<Block>(block));
return self->stmt2ref.at(producer_across_scope[0].get());
}
};

@@ -269,7 +301,7 @@ class BaseInliner : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(loop);
}

Stmt VisitStmt_(const BlockNode* block) final {
Stmt VisitStmt_(const BlockNode* block) {
CheckMatchBufferRegion(block);
AddBuffersInBlockSignature(block);
Block src_block = GetRef<Block>(block);
@@ -284,31 +316,6 @@ return std::move(tgt_block);
return std::move(tgt_block);
}

/*!
* \brief Count the number of undefined variables that are not used
* as buffer objects.
*
* This is used to determine whether inlining or reverse inlining is
* possible. The only undefined variables present should be the
* load/store indices, or buffer access based on those indices.
*
* \param stmt The statement in which to count undefined variables
*/
static int GetNumUndefinedNonpointerVars(const Stmt& stmt) {
auto undefined_vars = UndefinedVars(stmt, {});
// Buffer pointers and the inlined indices are allowed, but no
// other variables may appear in the inlined block.
int num_nonpointer_vars = 0;
for (const auto& var : undefined_vars) {
bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() &&
var->type_annotation.as<PointerTypeNode>();
if (!is_pointer) {
num_nonpointer_vars++;
}
}
return num_nonpointer_vars;
}

private:
/*!
* \brief Add the buffers in the block signature to the `buffer_var_map_`,
@@ -406,7 +413,7 @@ class BaseInliner : public StmtExprMutator {
/*! \brief Maps a buffer's data field to itself */
Map<Var, Buffer> buffer_var_map_;
/*! \brief The indices used for indexing the buffer to be inlined */
std::vector<const VarNode*> idx_vars_;
std::vector<Var> idx_vars_;
/*! \brief The mapping to substitute index variables to PrimExprs */
std::unordered_map<const VarNode*, PrimExpr> idx_sub_;

@@ -443,10 +450,62 @@ class ComputeInliner : public BaseInliner {
return false;
}

int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) {
// Fast path for the trivial case:
// check whether the store indices are the same as the block iters.
store_value_ = inlined_store_->value;
size_t num_iters = producer_block->iter_vars.size();
size_t buffer_ndim = inlined_store_->indices.size();
if (num_iters == buffer_ndim) {
std::vector<Var> idx_vars;
idx_vars.reserve(num_iters);
for (size_t i = 0; i < num_iters; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
const PrimExpr& e = inlined_store_->indices[i];
if (e.same_as(iter->var) ||
(analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) &&
analyzer_.CanProveEqual(iter->dom->extent, 1))) {
idx_vars.push_back(iter->var);
} else {
break;
}
}
if (idx_vars.size() == num_iters) {
// match success
idx_vars_ = std::move(idx_vars);
return true;
}
}

// If the mapping for store indices is non-trivial
// check bijective mapping from producer iter var to store indices
Map<Var, Range> producer_iter_doms;
for (const auto& iter : producer_block->iter_vars) {
producer_iter_doms.Set(iter->var, iter->dom);
}
arith::IterMapResult res = arith::DetectIterMap(
/*indices=*/inlined_store_->indices,
/*input_iters=*/producer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
if (!res->errors.empty()) {
// Failure: indices of BufferStore are not bijective affine
return false;
}
idx_vars_.resize(buffer_ndim);
for (size_t i = 0; i < idx_vars_.size(); ++i) {
idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype());
}
auto inverse_iter_map = arith::InverseAffineIterMap(
res->indices, Array<PrimExpr>(idx_vars_.begin(), idx_vars_.end()));
for (const auto& iter : producer_block->iter_vars) {
if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) {
// fallback mapping for constant iters
inverse_iter_map.Set(iter->var, iter->dom->min);
}
}
store_value_ = Substitute(store_value_, inverse_iter_map);
return true;
}

@@ -464,45 +523,7 @@

PrimExpr ReplaceInlinedBuffer(BufferLoad load) {
SetIndexSubstitution(load->indices);
return Substitute(inlined_store_->value, idx_sub_);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
return Substitute(store_value_, idx_sub_);
}

/*!
@@ -512,11 +533,17 @@
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
idx_sub_[idx_vars_[i].get()] = indices[i];
}
}

/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
/*! \brief The store value used for inlining. If the producer
store indices are trivial, it is expressed w.r.t. the producer block iter vars;
otherwise it is expressed w.r.t. the placeholder vars of the store indices. */
PrimExpr store_value_;
};

/*!
@@ -534,7 +561,9 @@ class ReverseComputeInliner : public BaseInliner {
private:
PrimExpr VisitExpr_(const VarNode* var) final {
auto it = self_->idx_sub_.find(var);
ICHECK(it != self_->idx_sub_.end());
if (it == self_->idx_sub_.end()) {
return GetRef<Var>(var);
}
return (*it).second;
}

@@ -552,8 +581,7 @@
const StmtSRef& scope_root_sref, const IRModule& mod)
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block),
consumer_block_(consumer_block_realize->block.get()),
mod_(mod) {
consumer_block_(consumer_block_realize->block.get()) {
// Initialize the predicates to ensure consumer block iters are in-bound
consumer_iter_in_bound_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
@@ -596,7 +624,7 @@ }
}
}

auto res = arith::DetectIterMap(
arith::IterMapResult res = arith::DetectIterMap(
/*indices=*/buffer_load_indices_,
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
@@ -609,7 +637,15 @@
return false;
}

const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>();
const BufferStoreNode* producer_store = nullptr;
if (const auto* producer_if = producer_block_->body.as<tir::IfThenElseNode>()) {
if (producer_if->else_case.defined()) {
return false;
}
producer_store = producer_if->then_case.as<BufferStoreNode>();
} else {
producer_store = producer_block_->body.as<BufferStoreNode>();
}
if (producer_store == nullptr) {
// Failure: producer block body is not BufferStore
return false;
@@ -628,39 +664,41 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitStmt_;

/*! \brief Generate the predicate after inlining based on the consumer predicate */
PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) {
Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
// Bind the producer block iter domains for simplification
Map<Var, PrimExpr> subst_map;
for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) {
const IterVar& iter = producer_block_realize->block->iter_vars[i];
for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
subst_map.Set(iter->var, producer_block_realize->iter_values[i]);
}
// Substitute the consumer block iters with the corresponding iters in the producer blocks
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer_.Simplify(predicate);
// Substitute the producer block iters with their bindings since the predicate in BlockRealize
// should not contain the block iters
predicate = Substitute(predicate, subst_map);
predicate = analyzer_.Simplify(predicate);
return predicate;
ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
if (is_one(predicate)) {
return Block(block);
}
if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
if (!StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
}
block->body = IfThenElse(predicate, if_->then_case);
return Block(block);
}
block->body = IfThenElse(predicate, block->body);
return Block(block);
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (op->block.get() == producer_block_) {
auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get());

With<arith::ConstraintContext> ctx(&analyzer_, new_predicate);
if (!analyzer_.CanProve(op->predicate)) {
// We do not allow cases where the new predicate for the inlined block cannot
// imply the original predicate in the producer block.
throw ProducerHasNonTrivialPredicateError(mod_, GetRef<BlockRealize>(op), new_predicate);
}
new_block_realize.CopyOnWrite()->predicate = new_predicate;
Stmt VisitStmt_(const BlockNode* op) final {
Block src_block = GetRef<Block>(op);
Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
if (op == producer_block_) {
tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
block_reuse.Set(src_block, tgt_block);
}
return std::move(new_block_realize);
return std::move(tgt_block);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
@@ -774,8 +812,6 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr consumer_iter_in_bound_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
/*! \brief The target module, only used for error reporting. */
const IRModule& mod_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
