diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 8033294a0f99..c4ee7b5b6279 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -128,17 +128,17 @@ class ConstIntBoundAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param override Whether do we allow override of existing information. + * \param allow_override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false); + TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false); /*! * \brief Bind variable to a range. * * \param var The variable. * \param range The range we bind to. - * \param override Whether we allow overriding an existing var's range. + * \param allow_override Whether we allow overriding an existing var's range. */ - TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); + TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false); private: friend class Analyzer; @@ -217,9 +217,9 @@ class ModularSetAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param override Whether do we allow override of existing information. + * \param allow_override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false); + TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false); private: friend class Analyzer; @@ -256,9 +256,9 @@ class RewriteSimplifier { * * \param var The variable of interest. * \param new_expr - * \param override Whether do we allow override of existing information. + * \param allow_override Whether do we allow override of existing information. 
*/ - TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -290,9 +290,9 @@ class CanonicalSimplifier { * * \param var The variable of interest. * \param new_expr - * \param override Whether do we allow override of existing information. + * \param allow_override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); private: friend class Analyzer; @@ -404,9 +404,9 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param expr The expression we bind to. - * \param override Whether we allow overriding an existing var's expression. + * \param allow_override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const PrimExpr& expr, bool override = false); + void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -415,16 +415,16 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param range The range we bind to. - * \param override Whether we allow overriding an existing var's expression. + * \param allow_override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const Range& range, bool override = false); + void Bind(const Var& var, const Range& range, bool allow_override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. - * \param override Whether we allow overriding an existing var's expression. + * \param allow_override Whether we allow overriding an existing var's expression. 
*/ - void Bind(const Map& variables, bool override = false); + void Bind(const Map& variables, bool allow_override = false); /*! * \brief Whether can we prove expr >= val. diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 34cb52f90211..31ce13c7e66a 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -671,11 +671,18 @@ inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } /*! - * \brief Check whether x is a constant. + * \brief Check whether x is an integer constant. * \note This only return true for integer types. * \return whether x is constant */ -inline bool is_const(const PrimExpr& x); +inline bool is_const_int(const PrimExpr& x); + +/*! + * \brief Check whether x is an integer/float constant. + * \note This returns true for both integer and floating-point constants. + * \return whether x is constant + */ +inline bool is_const_number(const PrimExpr& x); /*! * \brief Left fold. @@ -699,7 +706,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array()) { return true; } else if (const auto* op = x.as()) { @@ -711,6 +718,17 @@ inline bool is_const(const PrimExpr& x) { return false; } +inline bool is_const_number(const PrimExpr& x) { + if (x.as()) { + return true; + } else if (x.as()) { + return true; + } else if (const auto* op = x.as()) { + return (op->value->IsInstance() || op->value->IsInstance()); + } + return false; +} + inline bool is_positive_const(const PrimExpr& a) { if (const tir::IntImmNode* op = a.as()) { return op->value > 0; @@ -742,7 +760,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) { inline bool is_no_op(const tir::Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as()) { - return is_const(op->value); + return is_const_int(op->value); } if (const auto* op = stmt.as()) { return op->seq.size() == 0; diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 037c76665d4b..c7a8365b9fda 100644 ---
a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -35,31 +35,31 @@ Analyzer::Analyzer() canonical_simplify(this), int_set(this) {} -void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) { +void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); - this->const_int_bound.Update(var, this->const_int_bound(new_expr), override); - this->modular_set.Update(var, this->modular_set(new_expr), override); - this->rewrite_simplify.Update(var, new_expr, override); - this->canonical_simplify.Update(var, new_expr, override); + this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override); + this->modular_set.Update(var, this->modular_set(new_expr), allow_override); + this->rewrite_simplify.Update(var, new_expr, allow_override); + this->canonical_simplify.Update(var, new_expr, allow_override); } -void Analyzer::Bind(const Var& var, const Range& range, bool override) { +void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { CHECK(range.defined()); if (tir::is_one(range->extent)) { - this->Bind(var, range->min, override); + this->Bind(var, range->min, allow_override); } else { - this->const_int_bound.Bind(var, range, override); + this->const_int_bound.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify } -void Analyzer::Bind(const Map& variables, bool override) { +void Analyzer::Bind(const Map& variables, bool allow_override) { for (const auto& iter : variables) { - this->Bind(iter.first, iter.second, override); + this->Bind(iter.first, iter.second, allow_override); } } @@ -116,9 +116,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) { } PrimExpr Analyzer::Simplify(const PrimExpr& expr) { - if (tir::is_const(expr)) return expr; + if (tir::is_const_int(expr)) return expr; auto res = this->rewrite_simplify(expr); - if (tir::is_const(res)) return 
res; + if (tir::is_const_int(res)) return res; res = this->canonical_simplify(res); return res; } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8c90249f4f17..be830d389209 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -96,17 +96,17 @@ class ConstIntBoundAnalyzer::Impl BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} }; - void Bind(const Var& var, const Range& range, bool override) { + void Bind(const Var& var, const Range& range, bool allow_override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); Entry ret; ret.min_value = a.min_value; ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); - Update(var, ret, override); + Update(var, ret, allow_override); } - void Update(const Var& var, const Entry& info, bool override) { - if (!override) { + void Update(const Var& var, const Entry& info, bool allow_override) { + if (!allow_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { CHECK(it->second == info) << "Trying to update var \'" << var << "\'" @@ -119,8 +119,21 @@ class ConstIntBoundAnalyzer::Impl var_map_[var] = info; } - void Update(const Var& var, const ConstIntBound& info, bool override) { - Update(var, MakeBound(info->min_value, info->max_value), override); + Entry VisitExpr_(const LetNode* op) final { + auto it = var_map_.find(op->var); + // if the var has not been bound, update the info.
+ if (it == var_map_.end()) { + var_map_[op->var] = this->VisitExpr(op->value); + Entry ret = VisitExpr(op->body); + var_map_.erase(op->var); + return ret; + } else { + return VisitExpr(op->body); + } + } + + void Update(const Var& var, const ConstIntBound& info, bool allow_override) { + Update(var, MakeBound(info->min_value, info->max_value), allow_override); } // Override visitor behaviors @@ -558,12 +571,12 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapTy return ConstIntBound(ret.min_value, ret.max_value); } -void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { - impl_->Update(var, info, override); +void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) { + impl_->Update(var, info, allow_override); } -void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) { - impl_->Bind(var, range, override); +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); } std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 108f08c4f78f..8c4176085896 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -89,8 +89,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctorsecond == info) << "Trying to update var \'" << var << "\'" @@ -118,6 +118,19 @@ class ModularSetAnalyzer::Impl : public ExprFunctorvar); + // if the var has not been bound, update the info.
+ if (it == var_map_.end()) { + var_map_[op->var] = this->VisitExpr(op->value); + Entry ret = VisitExpr(op->body); + var_map_.erase(op->var); + return ret; + } else { + return VisitExpr(op->body); + } + } + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } @@ -315,8 +328,8 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { return ModularSet(ret.coeff, ret.base); } -void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) { - impl_->Update(var, info, override); +void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) { + impl_->Update(var, info, allow_override); } std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 898eecc93845..e9d640ad660f 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1519,7 +1519,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { op = ret.as(); if (op == nullptr) return ret; - if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) { + if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) { return op->args[0]; } else if (op->op.same_as(tir::builtin::shift_right())) { if (op->args[0].as() && op->args[1].as()) { @@ -1559,9 +1559,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { return cast(op->dtype, op->value); } +bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) { + // Only inline trivial bindings to avoid deep expression explosion + // when we need let to construct complicated expressions. 
+ if (is_const_number(op->value)) return true; + if (op->value.as()) return true; + return false; +} + PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!tir::HasSideEffect(value)) { + if (CanInlineLet(op)) { // it is fine to discard the let binding // because the value will always be inlined in the simplifier. analyzer_->Bind(op->var, value); @@ -1587,8 +1595,8 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { return res; } -void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { - impl_->Update(var, info, override); +void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_override) { + impl_->Update(var, info, allow_override); } std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 68c0dd271410..258f833a7b21 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -98,6 +98,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { */ CompareResult TryCompare(const PrimExpr& x, int64_t val); + /*! + * \brief Internal function to check whether or not to inline let. + * \param op The let expr. + * \return The inline decision. 
+ */ + bool CanInlineLet(const LetNode* op); + private: // Whether x >= val bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index b65ae91c6393..67765f039714 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -381,7 +381,7 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) { bool is_noop(const Stmt& stmt) { if (!stmt.defined()) return true; - if (auto eval = stmt.as()) return is_const(eval->value); + if (auto eval = stmt.as()) return is_const_int(eval->value); return false; } @@ -409,7 +409,7 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { } void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { - if (is_const(op->value)) return; + if (is_const_int(op->value)) return; std::string str = PrintExpr(op->value); if (!str.empty()) stream << str << "\n"; } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 7ab26fae785f..ca038abab819 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -71,12 +71,14 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { } } -Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { - const auto* op = primFunc.operator->(); +Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { + const auto* op = prim_func.operator->(); const auto& signature = op->func_type_annotation(); // collect Meta in DictAttr - for (const auto& it : primFunc->attrs->dict) { - meta_collector_.Collect(it.second); + if (prim_func->attrs.defined()) { + for (const auto& it : prim_func->attrs->dict) { + meta_collector_.Collect(it.second); + } } // collect buffers in buffer_map memo_var_.clear(); @@ -100,46 +102,54 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { // print attr Doc attr_doc; std::vector attr_docs; - for (const auto& it : op->attrs->dict) { - attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << 
Print(it.second)); + if (prim_func->attrs.defined()) { + for (const auto& it : op->attrs->dict) { + attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; + doc << Doc::Indent(2, attr_doc); } - attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; - doc << Doc::Indent(2, attr_doc); + // print all the buffers in the tree - Doc buffer_doc; - std::vector buffer_docs; - for (const auto& it : memo_buf_) { - const auto& buf = it.first; - buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " - << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " - << Print(buf->strides)); - if (!is_zero(buf->elem_offset)) { - buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); - } - if (buf->scope != "global") { - buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); - } - if (buf->data_alignment != 128) { - buffer_docs.back() << ", align=" << buf->data_alignment; + if (memo_buf_.size() != 0) { + Doc buffer_doc; + std::vector buffer_docs; + for (const auto& it : memo_buf_) { + const auto& buf = it.first; + buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " + << PrintDType(buf->dtype) << ", " << Print(buf->shape) + << ", " << Print(buf->strides)); + if (!is_zero(buf->elem_offset)) { + buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != 128) { + buffer_docs.back() << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 1) { + buffer_docs.back() << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); + } + buffer_docs.back() << ")"; } - if (buf->offset_factor != 1) { - buffer_docs.back() << ", 
offset_factor=" << buf->offset_factor; - } - if (buf->buffer_type != 1) { - buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); - } - buffer_docs.back() << ")"; + buffer_doc << Doc::NewLine() << "buffers = {"; + buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); + doc << Doc::Indent(2, buffer_doc) << "}"; } - buffer_doc << Doc::NewLine() << "buffers = {"; - buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); - doc << Doc::Indent(2, buffer_doc) << "}"; - // print buffer_map - std::vector buffer_map_doc; - for (const auto& it : op->buffer_map) { - buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); + + if (op->buffer_map.size() != 0) { + // print buffer_map + std::vector buffer_map_doc; + for (const auto& it : op->buffer_map) { + buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); + } + doc << Doc::Indent( + 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } - doc << Doc::Indent( - 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); doc << PrintBody(op->body); return doc; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 05582fb07d6a..7c3c8309e115 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -914,7 +914,7 @@ void CodeGenC::VisitStmt_(const SeqStmtNode* op) { } void CodeGenC::VisitStmt_(const EvaluateNode* op) { - if (is_const(op->value)) return; + if (is_const_int(op->value)) return; const CallNode* call = op->value.as(); if (call) { if (call->op.same_as(builtin::tvm_storage_sync())) { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ae5e40acd8f5..7dc63d4ac949 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -609,7 +609,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { } void CodeGenCUDA::VisitStmt_(const 
EvaluateNode* op) { - if (is_const(op->value)) return; + if (is_const_int(op->value)) return; const CallNode* call = op->value.as(); if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { PrintIndent(); diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 84b14925877a..9cad92dfdacc 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -429,7 +429,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { } void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { - if (is_const(ev->value)) return; + if (is_const_int(ev->value)) return; const CallNode* op = ev->value.as(); if (op && op->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(op->args.size(), 4U); diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index c57cbf7d0703..834ad09cb61a 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -35,44 +35,60 @@ namespace tvm { namespace tir { -class IRVerifySSA final : public StmtExprVisitor { +class SSAVerifier final : public StmtExprVisitor { public: - bool is_ssa{true}; + bool is_ssa_{true}; void VisitExpr(const PrimExpr& n) final { - if (!is_ssa) return; + if (!is_ssa_) return; StmtExprVisitor::VisitExpr(n); } void VisitStmt(const Stmt& n) final { - if (!is_ssa) return; + if (!is_ssa_) return; StmtExprVisitor::VisitStmt(n); } void VisitExpr_(const LetNode* op) final { - MarkDef(op->var.get()); + // Weaker SSA condition + // A single var can be bound in multiple lets + // but they have to bind to the same value. + // This is used to enable cases when we reuse a single let + // expression to construct a nested expr.
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = def_map_.find(op->var); + if (it != def_map_.end()) { + if (!deep_equal_(it->second, op->value)) { + is_ssa_ = false; + return; + } + } else { + MarkDef(op->var, op->value); + } StmtExprVisitor::VisitExpr_(op); } + void VisitStmt_(const LetStmtNode* op) final { - MarkDef(op->var.get()); + MarkDef(op->var, op->value); StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const ForNode* op) final { - MarkDef(op->loop_var.get()); + MarkDef(op->loop_var, op->loop_var); StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AllocateNode* op) final { - MarkDef(op->buffer_var.get()); + MarkDef(op->buffer_var, op->buffer_var); StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const VarNode* node) final { + auto var = GetRef(node); if (match_scope_) { - MarkDef(node, true); + MarkDef(var, var, true); } } void Run(const PrimFunc& func) { for (auto param : func->params) { - MarkDef(param.get()); + MarkDef(param, param); } for (auto kv : func->buffer_map) { @@ -99,25 +115,28 @@ class IRVerifySSA final : public StmtExprVisitor { } private: - void MarkDef(const VarNode* v, bool allow_dup = false) { - if (defined_.count(v) != 0) { + void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) { + if (def_map_.count(var) != 0) { if (!allow_dup) { - is_ssa = false; + is_ssa_ = false; return; } } else { - defined_[v] = 1; + def_map_[var] = value; } } // whether we are in match scope, where a var can occur multiple times. bool match_scope_{false}; - std::unordered_map defined_; + // deep equal + ExprDeepEqual deep_equal_; + // def map, for let, maps to the bind value, for others maps to self. 
+ std::unordered_map def_map_; }; bool VerifySSA(const PrimFunc& func) { - IRVerifySSA visitor; + SSAVerifier visitor; visitor.Run(func); - return visitor.is_ssa; + return visitor.is_ssa_; } TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 0f67126be3e2..a0ba8d655232 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -395,7 +395,7 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) // likely PrimExpr likely(PrimExpr cond) { - if (is_const(cond)) return cond; + if (is_const_int(cond)) return cond; return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}); } diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 2fb8003486f1..1876dfe575f6 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -96,7 +96,7 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { // partition const loop when sets partition_const_loop_ - if (!is_const(op->min) || !is_const(op->extent) || partition_const_loop_) { + if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) { const VarNode* var = op->loop_var.get(); record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); @@ -115,7 +115,7 @@ class CandidateSelector final : public StmtExprVisitor { CHECK(iv); Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); - if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) { + if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 3be232964f36..3c8a9344d87f 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -54,9 
+54,18 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitStmt_(op); } + bool CanInlineLetStmt(const LetStmtNode* op) { + if (is_const_number(op->value)) return true; + if (op->value.as()) return true; + // Won't face the deep expression explosion problem as in Let expression. + // Attempt to inline as much as possible if the value has an integer type (it can be an index). + if (!op->value.dtype().is_int()) return false; + return !tir::HasSideEffect(op->value); + } + Stmt VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!tir::HasSideEffect(value)) { + if (CanInlineLetStmt(op)) { // it is fine to discard the let binding // because the call to simplify will always inline the var. analyzer_->Bind(op->var, value); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index f339c565959a..75ae743f79ef 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -70,7 +70,7 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) { return body; } else { PrimExpr value = this->VisitExpr(op->value); @@ -101,7 +101,7 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) { return body; } else { PrimExpr value = this->VisitExpr(op->value); @@ -149,6 +149,7 @@ class VarUseDefAnalysis : public StmtExprMutator { // The fields are publically readible to // be accessible to the users.
bool visit_thread_extent_{true}; + bool simplify_let_{true}; Array undefined_; Array thread_axis_; Array thread_extent_; @@ -158,6 +159,7 @@ class VarUseDefAnalysis : public StmtExprMutator { Array UndefinedVars(const Stmt& stmt, const Array& args) { VarUseDefAnalysis m; + m.simplify_let_ = false; for (Var arg : args) { m.use_count_[arg.get()] = 0; } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index e015990847e5..bf54ada6e837 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -23,6 +23,7 @@ // Loop vectorizer as in Halide pipeline. #include #include +#include #include #include #include @@ -91,15 +92,21 @@ class VecAllocAccess : public StmtExprMutator { int var_lanes_; }; -class Vectorizer : public StmtExprMutator { +// We use ExprFunctor directly instead of StmtExprMutator +// This is because the transformation can change the dtype of the Expr +// The existing ExprMutator transformation rules may not be well defined. 
+class Vectorizer : public StmtMutator, public ExprFunctor { public: + using ExprFunctor::VisitExpr; + using StmtMutator::operator(); + Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { ramp_ = Ramp(0, 1, var_lanes); } Stmt VisitStmt(const Stmt& stmt) final { CHECK(!need_scalarize_); - Stmt ret = StmtExprMutator::VisitStmt(stmt); + Stmt ret = StmtMutator::VisitStmt(stmt); if (need_scalarize_) { need_scalarize_ = false; return Scalarize(stmt); @@ -108,6 +115,8 @@ class Vectorizer : public StmtExprMutator { } } + PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); } + PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); } @@ -151,6 +160,16 @@ class Vectorizer : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } + + PrimExpr VisitExpr_(const NotNode* op) final { + PrimExpr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return GetRef(op); + } else { + return !(a); + } + } + PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); @@ -170,6 +189,20 @@ class Vectorizer : public StmtExprMutator { } return Shuffle::Concat(elems); } + + PrimExpr VisitExpr_(const BroadcastNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.dtype().lanes() != 1) { + need_scalarize_ = true; + return GetRef(op); + } + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Broadcast(op->value, op->lanes); + } + } + PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr t = this->VisitExpr(op->true_value); @@ -189,14 +222,25 @@ class Vectorizer : public StmtExprMutator { return 
Cast(op->dtype.with_lanes(value.dtype().lanes()), value); } } + + PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef(op); } + + PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef(op); } + + PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef(op); } + // Variable - PrimExpr VisitExpr_(const VarNode* v) final { - if (v == var_.get()) { + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + + if (var.same_as(var_)) { return ramp_; - } else if (lets_.count(v)) { - return lets_[v]; + } + auto it = let_binding_.find(var); + if (it != let_binding_.end()) { + return it->second; } else { - return GetRef(v); + return std::move(var); } } // IfThenElse expr @@ -267,12 +311,23 @@ class Vectorizer : public StmtExprMutator { // Let PrimExpr VisitExpr_(const LetNode* op) final { PrimExpr value = this->VisitExpr(op->value); - CHECK(!lets_.count(op->var.get())) << "not SSA"; + // Weaker SSA condition + // A single var can be bound in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr.
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + CHECK(deep_equal_(it->second, value)) + << "Let cannot bind the same var to two different values"; + } if (value.dtype().lanes() != op->value.dtype().lanes()) { - Var v(op->var->name_hint, value.dtype()); - lets_[op->var.get()] = v; - return Let(v, value, this->VisitExpr(op->body)); + Var new_var(op->var->name_hint, value.dtype()); + let_binding_[op->var] = new_var; + return Let(new_var, value, this->VisitExpr(op->body)); } else { + let_binding_[op->var] = op->var; PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); @@ -281,10 +336,6 @@ class Vectorizer : public StmtExprMutator { } } } - Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc"; - return Stmt(); - } // Store Stmt VisitStmt_(const StoreNode* op) final { PrimExpr value = this->VisitExpr(op->value); @@ -338,8 +389,23 @@ class Vectorizer : public StmtExprMutator { } // LetStmt Stmt VisitStmt_(const LetStmtNode* op) final { - LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize"; - return Scalarize(GetRef(op)); + PrimExpr value = this->VisitExpr(op->value); + CHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; + let_binding_[op->var] = value; + + if (value.dtype().lanes() != op->value.dtype().lanes()) { + Var new_var(op->var->name_hint, value.dtype()); + let_binding_[op->var] = new_var; + return LetStmt(new_var, value, this->VisitStmt(op->body)); + } else { + let_binding_[op->var] = op->var; + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { @@ -364,6 +430,7 @@ class Vectorizer : public 
StmtExprMutator { body = this->VisitStmt(body); return Allocate(op->buffer_var, op->dtype, extents, condition, body); } + // scalarize the statment Stmt Scalarize(Stmt stmt) { Var idx(var_->name_hint + ".s", var_->dtype); @@ -371,10 +438,17 @@ class Vectorizer : public StmtExprMutator { stmt = Substitute(stmt, values); return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } + // ProducerStore + Stmt VisitStmt_(const ProducerStoreNode* op) final { + LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc"; + return Stmt(); + } private: // analyzer arith::Analyzer analyzer_; + // deep equal + ExprDeepEqual deep_equal_; // variable to be replaced Var var_; // the lanes. @@ -383,8 +457,8 @@ class Vectorizer : public StmtExprMutator { PrimExpr ramp_; // flag to mark requirment of scalarization. bool need_scalarize_{false}; - // The lets - std::unordered_map lets_; + // Let binding + std::unordered_map let_binding_; // vectorizable property OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index 4829b97c348e..c5794cd126ef 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -284,7 +284,16 @@ def test_size_var_bound(): assert bd.max_value == bd.POS_INF +def test_let_bound(): + analyzer = tvm.arith.Analyzer() + x = te.var("x") + bd = analyzer.const_int_bound(tvm.tir.Let(x, 1, x + 1)) + assert bd.min_value == 2 + assert bd.max_value == 2 + + if __name__ == "__main__": + test_let_bound() test_dtype_bound() test_cast_bound() test_add_sub_bound() diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 01180d2efb69..7d9f739f9d12 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -159,8 +159,17 @@ def test_intersect(): assert 
m.coeff == 105 assert m.base == 23 +def test_let(): + analyzer = tvm.arith.Analyzer() + x = te.var("x") + y = te.var("y") + m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1)) + assert m.coeff == 10 + assert m.base == 1 + if __name__ == "__main__": + test_let() test_cast() test_add_sub() test_mul() diff --git a/tests/python/unittest/test_tir_analysis_verify_ssa.py b/tests/python/unittest/test_tir_analysis_verify_ssa.py index 8a15c3628074..57dd8261a2c6 100644 --- a/tests/python/unittest/test_tir_analysis_verify_ssa.py +++ b/tests/python/unittest/test_tir_analysis_verify_ssa.py @@ -27,6 +27,16 @@ def test_verify_ssa(): assert(not tvm.tir.analysis.verify_ssa( tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z)))) +def test_verify_weak_let_ssa(): + x = te.var('x') + z1 = tvm.tir.Let(x, 1, x + 1) + z2 = tvm.tir.Let(x, 2, x + 2) + + assert(tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1)))) + assert(not tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2)))) if __name__ == "__main__": test_verify_ssa() + test_verify_weak_let_ssa() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index ab730cd63d1e..c182d9ea4dad 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -274,7 +274,8 @@ def test_prim_func(): func = tvm.tir.PrimFunc( [x, y, b], stmt) - + # make sure we can print + func.astext() assert func.buffer_map[func.params[2]].same_as(b) assert len(func.buffer_map) == 1 diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index a69c9d36c693..0516b4a84d65 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -81,6 +81,20 @@ def test_vectorize_with_if(): assert isinstance(stmt.else_case, tvm.tir.For) +def test_vectorize_let(): + v = tvm.tir.Var("v", "float32") + ib = tvm.tir.ir_builder.create() + A
= ib.pointer("float32", name="A") + with ib.for_range(0, 4, for_type="vectorize") as i: + ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body)) + A[i] = v + 2 + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get())) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.LetStmt) + assert stmt.value.dtype == "float32x4" + + def test_vectorize_with_le_cond(): n = te.var('n') ib = tvm.tir.ir_builder.create() @@ -153,3 +167,4 @@ def test_vectorize_if_then_else(): test_vectorize_if_then_else() test_vectorize_with_le_cond() test_vectorize_with_ge_cond() + test_vectorize_let()