Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Pass Manager #5364

Merged
merged 1 commit into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};


/*!
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
Expand Down
140 changes: 0 additions & 140 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,144 +202,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
*/
bool VerifyCompactBuffer(Stmt stmt);

/*!
* \brief Remove No Op from the Stmt.
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt RemoveNoOp(Stmt stmt);

/*!
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attach pragma unroll tag to loops which meets the standard.
*
* \param stmt The statment to be unrolled.
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_max_depth The maximum depth before stop attach automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is an legacy option that do not take the loop total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);

/*!
* \brief vectorize the constant loops
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief convert vectorized loops into serialized loops
* \param stmt The statement to skip vectorization on.
* \return Transformed stmt.
*/
Stmt SkipVectorize(Stmt stmt);

/*!
* \brief instruments bound checkers.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);

/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);

/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return Transformed stmt.
*/
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const runtime::PackedFunc& fintrin);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);

/*!
* \brief partition loops in the stmt
* \param stmt The stmt to do loop partition
* \param split_const_loop flag to enable partition for const loop
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt, bool split_const_loop);

/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);

/*!
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
Expand All @@ -356,15 +225,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
*/
Stmt HoistIfThenElse(Stmt stmt);

/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand Down
118 changes: 118 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<runtime::String>& required);

/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return The pass.
*/
TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
runtime::PackedFunc fintrin);

/*!
* \brief Detect and insert sync points to co-processor.
*
* \return The pass.
*/
TVM_DLL Pass CoProcSync();

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param attr_key The attribute key to be checked.
* \return The pass.
*/
TVM_DLL Pass LiftAttrScope(std::string attr_key);

/*!
* \brief partition loops in the stmt.
*
* \param split_const_loop flag to enable partition for const loop
*
* \return The pass.
*/
TVM_DLL Pass LoopPartition(bool split_const_loop);

/*!
* \brief Lower vectorization loops.
*
* \param enable_vectorize Whether vectorization is enabled.
*
* \return The pass.
*/
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);

/*!
* \brief Inject virtual thread loops.
*
* \return The pass.
*/
TVM_DLL Pass InjectVirtualThread();

/*!
* \brief Inject double buffer statements.
*
* \param split_loop_factor Loop splitting factor.
* \return The pass.
*/
TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \return The pass.
*/
TVM_DLL Pass StorageRewrite();

/*!
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attach pragma unroll tag to loops which meets the standard.
*
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_max_depth The maximum depth before stop attach automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is an legacy option that do not take the loop total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return The pass.
*/
TVM_DLL Pass UnrollLoop(int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);

/*!
* \brief Remove No Op from the Stmt.
*
* \return The pass.
*/
TVM_DLL Pass RemoveNoOp();

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
*
* \return The pass.
*/
TVM_DLL Pass RewriteUnsafeSelect();

/*!
* \brief Run arithmetic simplifications on the statements and expressions.
*
* \return The pass.
*/
TVM_DLL Pass Simplify();

/*!
* \brief Instruments bound checkers.
*
* \return The pass.
*/
TVM_DLL Pass InstrumentBoundCheckers();

/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)

for f in lower_phase2:
stmt = f(stmt)

Expand All @@ -187,11 +188,14 @@ def lower(sch,
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)

for f in lower_phase3:
stmt = f(stmt)

# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)

if simple_mode:
return stmt

Expand Down