Skip to content

Commit

Permalink
Re-design Indices lowering (apache#29)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd

* upd

* upd

* init

* format

* fix

* revise coding-style

* format
  • Loading branch information
yzh119 committed Dec 9, 2021
1 parent f25aa07 commit 0b0e371
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 402 deletions.
170 changes: 0 additions & 170 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,176 +404,6 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);

/******** SparseTIR Tools ********/

/*!
* \brief Maps sparse buffers to the array of sparse iterators we used to index the buffer.
*/
using BufferAccessMap = Map<SparseBuffer, Array<SpIterVar>>;
/*!
* \brief Maps sparse_iter to (sparse_buffer, i), indicates sparse_iter was used
* in the i-th dimension of sparse_buffer.
*/
using DependencyMap =
std::unordered_map<SpIterVar, std::pair<SparseBuffer, int>, ObjectPtrHash, ObjectPtrEqual>;

/*!
* \brief Check whether a given SparseBuffer contains the given axis.
* \param buffer The SparseBuffer to be checked.
* \param axis The axis to be checked.
* \return A boolean indicating whether the given SparseBuffer contains the
* given axis
*/
bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis);

/*!
* \brief For each sparse-fixed or sparse-variable iterator, collect the
* iterators that it depends on.
*/
class AccessAndDependencyCollector : public StmtExprVisitor {
public:
/*!
* \brief Collect access and dependency information from the given statement.
* \param stmt The statement node to collect in the AST.
*/
void Collect(Stmt stmt) {
VisitStmt(std::move(stmt));

for (const std::pair<SparseBuffer, Array<SpIterVar>>& kv_pair : buffer_access_map_) {
const SparseBuffer& buffer = kv_pair.first;
const Array<SpIterVar>& sp_iters = kv_pair.second;
int ndim = static_cast<int>(sp_iters.size());
for (int k = 0; k < ndim; ++k) {
const SpIterVar& sp_iter = sp_iters[k];
if (sp_iter->kind == SpIterKind::kDenseFixed ||
!BufferContainsAxis(buffer, sp_iter->axis)) {
continue;
}

auto it = dependency_map_.find(sp_iter);
if (it == dependency_map_.end()) {
dependency_map_[sp_iter] = std::make_pair(buffer, k);
} else {
const Array<SpIterVar>& dependent_iters = buffer_access_map_[it->second.first];
for (int i = 0; i < k; ++i) {
CHECK(sp_iters[i].same_as(dependent_iters[i]))
<< "ValueError: A SpIterVar can only depend on a fixed set of "
"iterators";
}
}
}
}
}

/*!
* \brief Collect the dependent buffer and iterators current sparse iterator depends on.
* \param sp_iter The sparse iterator.
* \param iterated_buffer The sparse buffer that given sparse iterator depends on.
* \param dependent_iters The sparse iterators that given sparse iterator depends on in the
* program.
* \note iterated_buffer and dependent_iters were pointers used as return values.
*/
void GetIteratedBufferAndDependentIters(const SpIterVar& sp_iter, SparseBuffer* iterated_buffer,
Array<PrimExpr>* dependent_iters) {
SparseBuffer dependent_buf;
int n_dependent;
std::tie(dependent_buf, n_dependent) = dependency_map_[sp_iter];
Array<SpIterVar> buffer_access_iters = buffer_access_map_[dependent_buf];

*iterated_buffer = std::move(dependent_buf);
*dependent_iters = Array<PrimExpr>();
dependent_iters->reserve(n_dependent);
for (int i = 0; i < n_dependent; ++i) {
dependent_iters->push_back(buffer_access_iters[i]->var);
}
}

/*!
* \brief Get sparse iterator corresponding to the given variable.
* \param index The variable
*/
SpIterVar GetSpIterFromIndex(PrimExpr index) {
auto it = var_sp_iter_map_.find(index.as<VarNode>());
CHECK(it != var_sp_iter_map_.end())
<< "ValueError: Currently an index is only allowed to be SpIterVar";
return it->second;
}

private:
/*!
* \brief Update the buffer access map given a sparse buffer access pattern.
* \param buffer The buffer to be accessed.
* \param indices The indices used to access the sparse buffer.
* \note We don't support use two set of indices to access the same buffer, and will throw
* an error in this case. For example, we can not access sparse buffer A with A[i, j]
* and A[j, i] in the same program.
* TODO(zihao, ruihang): fix the behavior in the future.
*/
void AddAccessPattern(const SparseBuffer& buffer, const Array<PrimExpr>& indices) {
int ndim = buffer->ndim();
CHECK_EQ(static_cast<int>(indices.size()), ndim);

Array<SpIterVar> iters;
iters.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
iters.push_back(GetSpIterFromIndex(indices[i]));
}

BufferAccessMap::iterator it = buffer_access_map_.find(buffer);
if (it == buffer_access_map_.end()) {
buffer_access_map_.Set(buffer, iters);
} else {
ICHECK_EQ(static_cast<int>((*it).second.size()), ndim);
for (int i = 0; i < ndim; ++i) {
CHECK((*it).second[i].same_as(iters[i]))
<< "ValueError: Currently all accesses to a same buffer are "
"required to be the same";
}
}
}

/*!
* \brief The visit function to collect variable to sparse iterator mapping for sparse block node.
* \param sp_block The sparse block node in AST.
*/
void VisitStmt_(const SparseBlockNode* sp_block) final {
for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) {
var_sp_iter_map_[sp_iter->var.get()] = sp_iter;
}
StmtVisitor::VisitStmt_(sp_block);
}

/*!
* \brief The visit function to collect buffer access pattern from sparse buffer stores.
* \param store The sparse buffer store node in AST.
*/
void VisitStmt_(const SparseBufferStoreNode* store) final {
ExprVisitor::VisitExpr(store->value);
AddAccessPattern(store->buffer, store->indices);
}

/*!
* \brief The visit function to collect buffer access pattern from sparse buffer loads.
* \param load The sparse buffer load node in AST.
*/
void VisitExpr_(const SparseBufferLoadNode* load) final {
AddAccessPattern(load->buffer, load->indices);
}

BufferAccessMap buffer_access_map_;
DependencyMap dependency_map_;
std::unordered_map<const VarNode*, SpIterVar> var_sp_iter_map_;
};

/*!
* \brief Check whether the new order satisfies the iterator dependency constraints
* \param self The schedule state
* \param block The sparse block, which is the source of the constraints
* \param new_order The new iterator order to be checked
*/
void CheckDependency(const ScheduleState& self, const SparseBlock& block,
const Array<SpIterVar>& new_order);

} // namespace tir
} // namespace tvm

Expand Down
62 changes: 0 additions & 62 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1377,68 +1377,6 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
return GetRef<StmtSRef>(p);
}

/******** SparseTIR Tools ********/

bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) {
for (int i = 0; i < static_cast<int>(buffer->axes.size()); ++i) {
if (buffer->axes[i].same_as(axis)) {
return true;
}
}
return false;
}

void CheckDependency(const ScheduleState& self, const SparseBlock& block,
const Array<SpIterVar>& new_order) {
class DependentIterNotAppearError : public ScheduleError {
public:
explicit DependentIterNotAppearError(IRModule mod, SpIterVar iter, SpIterVar dependent_iter)
: mod_(std::move(mod)),
iter_(std::move(iter)),
dependent_iter_(std::move(dependent_iter)) {}

String FastErrorString() const final {
return "ScheduleError: The new order violates some iterator dependency";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "ScheduleError: Iterator " << iter_ << " depends on " << dependent_iter_
<< ", while the latter iterator does not appear before the former iterator in the new "
"order";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

IRModule mod_;
SpIterVar iter_;
SpIterVar dependent_iter_;
};

AccessAndDependencyCollector collector;
collector.Collect(block);

for (int i = 0; i < static_cast<int>(new_order.size()); ++i) {
const SpIterVar& sp_iter = new_order[i];
if (sp_iter->kind == SpIterKind::kDenseFixed) {
continue;
}

SparseBuffer iterated_buffer{nullptr};
Array<PrimExpr> iters{nullptr};
collector.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &iters);

for (const PrimExpr& index : iters) {
const SpIterVar dependent_iter = collector.GetSpIterFromIndex(index);
if (std::find(new_order.begin(), new_order.begin() + i, dependent_iter) ==
new_order.begin() + i) {
throw DependentIterNotAppearError(self->mod, sp_iter, dependent_iter);
}
}
}
}

} // namespace tir
} // namespace tvm
3 changes: 2 additions & 1 deletion src/tir/schedule/primitive/sparse_loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block,
CheckValidInputIterators(self, new_order, block->sp_iter_vars);

// Step 2. Check whether the new order does not break the iterator dependency.
CheckDependency(self, block, new_order);
// TODO(zihao): use axis dependency tree instead
// CheckDependency(self, block, new_order);

// Step 3. Create the new SparseBlock.
ObjectPtr<SparseBlockNode> p_new_block = make_object<SparseBlockNode>(*block.get());
Expand Down
Loading

0 comments on commit 0b0e371

Please sign in to comment.