Commit

Format and Buffer data structure (apache#1)
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (apache#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

[CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (apache#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <expye@outlook.com>

Fix AxisTree (apache#3)

* fix axis tree

* upd

[SparseTIR] Add SparseBufferLoad/SparseBufferStore (apache#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

[SparseTIR] Introduce SpIterVar (apache#6)

* [SparseTIR] Introduce SpIterVar

* Add conversion to PrimExpr

[BugFix] Fix binary search & SpIterVar (apache#7)

[BugFix] Add field `is_reduction` for SpIterVar (apache#9)

* [BugFix] Add field `is_reduction` for SpIterVar

* Formatting

[SparseTIR] Index Lowering (apache#8)

* Add StmtFunctor/ExprFunctor for SparseBufferStore/Load

* Add basic index lowering

* Finish index lowering (maybe)

* Address comments

* Convert CRLF to LF

Frontend update, demo scripts. (apache#10)

* upd

* upd

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

[SparseTIR] SparseBlock on C++/Python side (apache#11)

* Fix a bug in the last commit

* SparseBlock on C++ & Python side

[BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (apache#12)

* Update `cord` and `pos`

* Fix `idtype`

* Formatting..

* Bug fix 1

* Move new special stmts

* Parser for Axis and SpIterVar

* Fix context_maintainer.py

[SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (apache#13)

* Enhance SparseBlock to have enough PrimFunc info

* Remove `func_sparse_buffer_map_`

* Don't print the map uh-huh

[SparseTIR] Parser, Printer, Roundtrip (apache#14)

* SparseBlock scope handler (part 1)

* SparseBlock scope handler (part 2)

* SparseBlock scope handler (part 3)

* SparseBlock scope handler (fix 1)

* Add SparseBufferLoad/Store on Python side

* Parser for SparseBufferLoad/Store

* Add SparseBlock to Python __init__

* StmtFunctor for SparseBlock

* Ensure at least one dimension for SparseBuffer

* Make `axis` field of SpIterVar mandatory

* SparseBlock scope handler (fix 2)

* Update Axis syntax by removing `name` parameter

* Move to intrin.py

* Add field `from_sparse` to DenseFixedAxis

* SparseTIR script printer

* Roundtrip test

* `update_symbol` bug fix

* Fix attr visit in SparseBuffer

* Define then compare in SparseBlock

* Fix printer bug for SparseBuffer

* Enable graph match for Axis and SparseBuffer

* Complete HashReduce and EqualReduce for AxisTree and SparseBuffer

* Fix typo

* Rename test

* Bug fix 1

* Bug fix 2

* Add more tests

Move tests (apache#15)

[SparseTIR] ReprPrinter for Axis and SpIterVar (apache#16)

upd (apache#17)

flatten (apache#18)

ELL and BSR correctness test scripts (apache#19)

[SparseTIR] SparseTIR Lowering (apache#20)

* Fix a previous bug of sparse-fixed SpIterVar creation

* Fix a previous bug in `GetDenseValue`

* Refactor Collector and IndexTransformer

* Construct block and loops

* Fix a previous bug which rejects DV iters in collector

* Update buffer map

* Create root block

* Fix bug of sparse-fixed SpIterVar creation

* Fix bug on SpIterVar conversion (with refactor)

* Fix bug when getting dependent SpIterVars

* Fix bug on dependency map and index lowering

* Full block read/write region

* Test version 1

* Fix bug of loop order

* Fix bug of batch-mm iterator ordering

* Update PrimFunc args to use symbolic params

* Fix bug of test "csr_element_wise"

* Fix bug of index accumulation for sparse-fixed axis

* Update correctness test

* Test structural equality

* Refactor and use Array

fix nnz cols

Add docstring for sparse tir lowering (apache#21)

* add docstring

* upd

Add more examples part 1 (sddmm) (apache#22)

* upd

* upd

* upd

[SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (apache#23)

* Test initialization

* Fix a stupid bug of ReprPrinter

* Add SparseBlockRV

* Schedule: GetSparseBlock

* Schedule: Reorder

[SparseTIR][Schedule] GetSpIters (apache#24)

remove hybrid script for successful compilation

Add atomic intrinsic for output nonzero inference. (apache#25)

* upd

* upd

Add "sparse" block attribute. (apache#26)

Revert "remove hybrid script for successful compilation"

This reverts commit eebd7c1.

[SparseTIR] Hack `IsAffineBinding` check (apache#27)

* [TensorIR][Schedule] Inherit block annotation upon creating new blocks

* Fix SDDMM test

* Hack IsAffineBinding for sparse blocks

Axis Dependency Tree aware code-gen and bmm example (apache#28)

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* remove redundancy

* fix

* upd

* upd

Re-design Indices lowering (apache#29)

* upd

* upd

* upd

* upd

* upd

* init

* format

* fix

* revise coding-style

* format

Complete indices lowering (apache#30)

* upd

* upd

* upd

* done

* upd

* passed test

* upd

Add more docstrings and suppress warnings for the new lowering algorithm. (apache#31)

Refactor derived axis, frontend support of fusion. (apache#32)

* upd

* upd

* fix

Fatal bugfix and change the signature of DenseVariableAxis. (apache#33)

Syntax simplification (apache#34)

Change the order of generated blocks for block isolation. (apache#35)

* upd

* upd

* upd

Syntax of AttachAxis for BMM (apache#36)

* upd

* upd

* upd

[SparseTIR] Add "square sum" lowering test (apache#37)

* Add square sum test

* Remove pylint comment

[BugFix] Fix offset caching in lowering (apache#38)

* Hack compact dataflow check in a dirty way

* Add two-K square sum test

* Mark skipped tests

* Fix offset saving in lowering

Fusion syntax fix + SDDMM example. (apache#39)

Some structural changes to updating offsets. (apache#40)

[Refactor] SparseTIR Lowering (apache#41)

* Take out methods in Scope

* Refactor

* Refactor "match"

* Tweak scope contents

* Refactor ViewIndexInAxis

* Refactor Scope

* SDDMM tests under implementation

* Refactor block stack

* Use Map for var_map

* Extract NeedCreateNewBlock

* Simplify SpIterVarToIterVar via GetIterExtent

* Refactor NeedCreateNewBlock

* Add docstring

* Use "auto" correctly

* Minor refactor and use some move

Remove redundant analyzers (apache#42)

Support indices lowering for attach and fuse. (apache#43)

* upd

* upd

* upd

Fix irregular BMM example. (apache#44)

* upd

* upd

* upd

* upd

RGCN forward and butterfly pattern example. (apache#45)

Fused SDDMM example. (apache#46)

* upd

* wip

* fix

Fix sparse reorder after refactor (apache#47)

[Refactor] Refactor Unittest (apache#48)

* upd

* remove redundancy

[Unittest] Correctness test for benchmarking scripts (apache#49)

Bugfix and more tests for axis fusion, new workload (apache#50)

* upd

* upd

upd
yzh119 authored and MasterJH5574 committed Mar 4, 2022
1 parent 9679e68 commit 7540891
Showing 57 changed files with 6,241 additions and 23 deletions.
15 changes: 15 additions & 0 deletions include/tvm/tir/builtin.h
@@ -494,6 +494,21 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
TVM_DLL const Op& tvm_warp_shuffle_down();
TVM_DLL const Op& tvm_warp_activemask();

/*!
* \brief Lower bound function for binary search.
*/
TVM_DLL const Op& tvm_lower_bound();

/*!
* \brief Upper bound function for binary search.
*/
TVM_DLL const Op& tvm_upper_bound();

/*!
* \brief Atomic add function.
*/
TVM_DLL const Op& tvm_atomic_add();

/*!
* \brief Initialize the global barrier.
* Call this at the beginning of kernels that need a global barrier.
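
The three intrinsics declared above (tvm_lower_bound, tvm_upper_bound, tvm_atomic_add) are only forward-declared in this header. As a rough reference for their semantics — an assumption on my part, not spelled out in this hunk — the two search intrinsics follow the usual std::lower_bound / std::upper_bound conventions on a sorted array:

#include <cstdint>

// Reference sketch only: first index in [l, r) whose element is >= val
// (lower bound) or > val (upper bound); returns r if no such index exists.
int64_t ref_lower_bound(const int32_t* arr, int32_t val, int64_t l, int64_t r) {
  while (l < r) {
    int64_t mid = l + (r - l) / 2;
    if (arr[mid] < val) {
      l = mid + 1;
    } else {
      r = mid;
    }
  }
  return l;
}

int64_t ref_upper_bound(const int32_t* arr, int32_t val, int64_t l, int64_t r) {
  while (l < r) {
    int64_t mid = l + (r - l) / 2;
    if (arr[mid] <= val) {
      l = mid + 1;
    } else {
      r = mid;
    }
  }
  return l;
}
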
53 changes: 53 additions & 0 deletions include/tvm/tir/expr.h
@@ -34,6 +34,7 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
#include <tvm/tir/var.h>

#include <algorithm>
@@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
* \brief Load a value from a high-dimensional sparse buffer.
*
* \code
*
* value = buffer[i, j];
*
* \endcode
* \sa SparseBufferStore
*/
class SparseBufferLoadNode : public PrimExprNode {
public:
/*! \brief The buffer to be loaded. */
SparseBuffer buffer;
/*! \brief The indices of the location to be loaded. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.SparseBufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to SparseBufferLoadNode.
* \sa SparseBufferLoadNode
*/
class SparseBufferLoad : public PrimExpr {
public:
TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode);
};

/*!
* \brief Load value from the result produced by the producer.
*
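
Going only by the constructor declared above, a SparseBufferLoad node can be built programmatically like this; a minimal sketch, assuming a SparseBuffer and its index expressions already exist (the helper name and the two-index shape are illustrative, not part of the patch):

#include <tvm/tir/expr.h>

// Build the expression `buf[i, j]`; the node's dtype is expected to come
// from the buffer's element type (an assumption, not shown in this hunk).
tvm::PrimExpr MakeSparseLoad(tvm::tir::SparseBuffer buf, tvm::PrimExpr i, tvm::PrimExpr j) {
  return tvm::tir::SparseBufferLoad(buf, {i, j});
}
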
4 changes: 4 additions & 0 deletions include/tvm/tir/expr_functor.h
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -165,6 +166,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
@@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const SparseBufferLoadNode* op) override;
void VisitExpr_(const ProducerLoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
@@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override;
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
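
With the dispatch hooks above in place, downstream passes can handle the new node like any other PrimExpr. A minimal sketch (class and member names are illustrative, not from the patch) of a visitor that counts sparse loads:

#include <tvm/tir/expr_functor.h>

class SparseLoadCounter : public tvm::tir::ExprVisitor {
 public:
  int count = 0;

 private:
  void VisitExpr_(const tvm::tir::SparseBufferLoadNode* op) final {
    ++count;
    // Fall back to the default implementation so the index expressions
    // are still visited.
    tvm::tir::ExprVisitor::VisitExpr_(op);
  }
};

Usage: construct a SparseLoadCounter, call it on a PrimExpr via operator(), then read count.
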
32 changes: 32 additions & 0 deletions include/tvm/tir/op.h
@@ -820,6 +820,38 @@ TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
*/
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());

/*!
* \brief Lower bound function for binary search
* \param arr The buffer variable of the array to be looked up in
* \param val The value to be looked up in the array
* \param l The left boundary of the look-up range (inclusive)
* \param r The right boundary of the look-up range (exclusive)
* \param span The location of this operation in the source
* \return The look-up result
*/
TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
Span span = Span());

/*!
* \brief Upper bound function for binary search
* \param arr The buffer variable of the array to be looked up in
* \param val The value to be looked up in the array
* \param l The left boundary of the look-up range (inclusive)
* \param r The right boundary of the look-up range (exclusive)
* \param span The location of this operation in the source
* \return The look-up result
*/
TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
Span span = Span());

/*!
* \brief Perform atomic add on ptr by val, and return the old value.
* \param ptr The address at which to perform the atomic add.
* \param val The value to add.
* \return The old value stored in ptr.
*/
TVM_DLL PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span = Span());

/*!
* \brief Calculate trunc(x)
* \param x The input expression.
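
A sketch of how the convenience wrappers above could be used when constructing TIR by hand, e.g. to locate a column inside one CSR row; the function and variable names are illustrative, and the enclosing namespace follows whatever tir/op.h uses for the neighbouring declarations:

#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::tir;

// First position in indices[row_start, row_end) whose value is >= col,
// assuming the usual lower_bound semantics documented above.
PrimExpr FindColumn(Var indices, PrimExpr col, PrimExpr row_start, PrimExpr row_end) {
  return lower_bound(indices, col, row_start, row_end);
}

atomic_add(ptr, val) likewise yields a PrimExpr that adds val to the value stored at ptr and evaluates to the old value, per its docstring.
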
56 changes: 56 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -22,6 +22,7 @@
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/sparse.h>

namespace tvm {
namespace tir {
@@ -85,6 +86,27 @@ using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** Random variable: SparseBlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR sparse block */
class SparseBlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.SparseBlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to SparseBlockRVNode
* \sa SparseBlockRVNode
*/
class SparseBlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL SparseBlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
};

/**************** The Schedule class ****************/

class Schedule;
@@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding expr
*/
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the sparse block corresponding to the specific random variable
* \param sp_block_rv The random variable to be looked up
* \return SparseBlock The corresponding sparse block
*/
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
@@ -188,6 +216,11 @@
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
/*!
* \brief Remove a sparse block random variable from the symbol table
* \param sp_block_rv The random variable to be removed
*/
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;

public:
/******** Schedule: Sampling ********/
@@ -524,6 +557,29 @@
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
/******** Schedule: SparseTIR schedules ********/
/*!
* \brief Retrieve a sparse block from a specific function by its name
* \param name The name of the sparse block to be retrieved
* \param func_name The name of the function
* \return The sparse block retrieved
* \note An indexing error is raised if zero or multiple blocks exist with the specified name
*/
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Retrieve the sparse iterators of a given sparse block
* \param block_rv The block to be queried
* \return The sparse iterators of the input sparse block
*/
virtual Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) = 0;
/*!
* \brief Reorder a list of sparse iterators. The new order must not break the iterator
* dependency.
* \param block_rv The block to be transformed
* \param new_order The new order of the sparse iterators, whose length should equal the number
* of the input block's sparse iterators
*/
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
};

/*!
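
Putting the new schedule surface together, a hedged usage sketch — the block name "csrmm" is hypothetical, sch is assumed to be a Schedule created elsewhere from a SparseTIR PrimFunc, and the reversal below is just some permutation (SparseReorder rejects orders that break the axis dependency):

#include <tvm/tir/schedule/schedule.h>

void ReorderSparseIters(tvm::tir::Schedule sch) {
  using namespace tvm::tir;
  SparseBlockRV blk = sch->GetSparseBlock("csrmm");  // func_name defaults to "main"
  Array<SpIterVar> iters = sch->GetSpIters(blk);     // current iterator order
  Array<SpIterVar> new_order;
  for (size_t i = iters.size(); i != 0; --i) {
    new_order.push_back(iters[i - 1]);               // reverse, purely for illustration
  }
  sch->SparseReorder(blk, new_order);
}
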
7 changes: 7 additions & 0 deletions include/tvm/tir/schedule/state.h
@@ -162,6 +162,13 @@
* \return A boolean flag indicating if the block has quasi-affine bindings
*/
bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
// (SparseTIR Hack) Always return true for sparse blocks.
const auto* block = block_sref->StmtAs<BlockNode>();
Optional<ObjectRef> sparse_attr = block != nullptr ? block->annotations.Get("sparse") : NullOpt;
if (sparse_attr.defined() && sparse_attr.as<IntImmNode>()->value == 1) {
return true;
}

return GetBlockInfo(block_sref).affine_binding;
}
/*!
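
The hack above keys on a block annotation "sparse" whose value is the integer 1; the attribute itself is added elsewhere in the series (apache#26). A sketch only — not taken from the patch — of the kind of code that would satisfy the check:

#include <tvm/ir/expr.h>
#include <tvm/tir/stmt.h>

// Attach the {"sparse": 1} annotation so IsAffineBlockBinding returns true
// for this block regardless of its actual bindings.
tvm::tir::Block MarkSparse(tvm::tir::Block block) {
  tvm::tir::BlockNode* n = block.CopyOnWrite();
  n->annotations.Set("sparse", tvm::IntImm(tvm::DataType::Int(32), 1));
  return block;
}
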