From fb15019c10e2f1761a097ce4e5f51d2a5461298b Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 20 Feb 2019 12:22:51 +0300 Subject: [PATCH 01/10] [TVM] Zero elimination --- include/tvm/ir_operator.h | 30 + python/tvm/testing.py | 184 ++ src/op/op_util.cc | 43 + src/op/op_util.h | 40 + src/pass/zero_elimination.cc | 1719 +++++++++++++++++ src/pass/zero_elimination.h | 239 +++ .../unittest/test_pass_zero_elimination.py | 464 +++++ 7 files changed, 2719 insertions(+) create mode 100644 src/pass/zero_elimination.cc create mode 100644 src/pass/zero_elimination.h create mode 100644 tests/python/unittest/test_pass_zero_elimination.py diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index c2cdc5e7a923..09a046228b5f 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) { */ inline bool is_const_int(const Expr& x, int64_t value); +/*! + * \brief Check if the given expr is a const of any type equal to the given integer value. + * \param e The expression. + * \param value The value to compare to. + * \return Whether the expression is a const equal to the value. + * \tparam ValueType The value type + */ +template +inline bool is_const_value(const Expr& e, ValueType value); + /*! * \brief Check whether stmt is nop. * \param stmt The input statement @@ -519,6 +529,26 @@ inline bool is_const_int(const Expr& x, int64_t value) { return false; } +template +inline bool is_const_value(const Expr& e, ValueType value) { + static_assert(std::is_integral::value, + "Comparison to non-integer values is forbidden."); + // This implementation was copy-pasted from HalideIR + if (const ir::IntImm* i = e.as()) { + return i->value == value; + } else if (const ir::UIntImm* i = e.as()) { + return (value >= 0) && (i->value == (uint64_t)value); + } else if (const ir::FloatImm* i = e.as()) { + return i->value == value; + } else if (const ir::Cast* c = e.as()) { + return is_const_value(c->value, value); + } else if (const ir::Broadcast* b = e.as()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as()) { diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1a6666bdee2a..afdca6a19720 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -1,6 +1,7 @@ """ TVM testing utilities """ import logging import numpy as np +import tvm def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set @@ -145,3 +146,186 @@ def compare_derivative(j, n_der, grad): logging.info("Numerical grad test wrt '%s' of shape %s passes, " "dist = %f, max_diff = %f, avg_diff = %f", x_name, grad.shape, dist, max_diff, avg_diff) + + +class PerformanceEstimate: + """A result of static performance estimation. + + Parameters + ---------- + iterations : int + The total number of iterations of all the loops. + + multiplications : int + The total number of expensive operations like multiplications. + + memory : int + The amount of memory to allocate. + """ + def __init__(self, iterations=0, multiplications=0, memory=0): + self.iterations = iterations + self.multiplications = multiplications + self.memory = memory + + def as_tuple(self): + return (self.iterations, self.multiplications, self.memory) + + def __add__(self, other): + return PerformanceEstimate(iterations=self.iterations + other.iterations, + multiplications=self.multiplications + other.multiplications, + memory=self.memory + other.memory) + + def max(self, other): + return PerformanceEstimate( + iterations=max(self.iterations, other.iterations), + multiplications=max(self.multiplications, other.multiplications), + memory=max(self.memory, other.memory)) + + def times(self, iters): + return PerformanceEstimate(iterations=self.iterations*iters, + multiplications=self.multiplications*iters, + memory=self.memory) + + def __repr__(self): + return "PerformanceEstimate(iterations={}, multiplications={}, memory={})".format( + self.iterations, self.multiplications, self.memory) + + def __le__(self, other): + return \ + self.iterations <= other.iterations and \ + self.multiplications <= other.multiplications and \ + self.memory <= other.memory + + +def estimate_performance(s, param_values=None, processed_ops=None): + """Statically estimate performance of statements, expressions and tensors. Note that the + estimate is very rough, it mustn't be used to predict future performance, its only purpose is + to detect possible performance regressions. + + Parameters + ---------- + s + A statement, an expression, a tensor, an operation, or a list + of any of the above. + + param_values : Dict[tvm.expr.Var, int], optional + Values for parameters (free variables). + + Returns + ------- + estimate : PerformanceEstimate + """ + from tvm import stmt + from tvm import expr + + if param_values is None: + param_values = {} + + if processed_ops is None: + processed_ops = {} + res = estimate_performance(s, param_values=param_values, processed_ops=processed_ops) + for op_est in processed_ops.values(): + res += op_est + return res + + def est(expression, param_values=param_values, processed_ops=processed_ops): + return estimate_performance(expression, + param_values=param_values, + processed_ops=processed_ops) + + def _eval(expression, param_values=param_values): + return tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expression, param_values)).value + + def _prod(elems): + res = 1 + for x in elems: + res *= x + return res + + if s is None or isinstance(s, (stmt.AssertStmt, stmt.Free, stmt.Prefetch, + expr.ConstExpr, expr.Var, tvm.tensor.PlaceholderOp)): + return PerformanceEstimate() + elif isinstance(s, list): + res = PerformanceEstimate() + for item in s: + res += est(item) + return res + elif s in processed_ops: + return PerformanceEstimate() + elif isinstance(s, stmt.Allocate): + mem = _prod([_eval(e) for e in s.extents]) + return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem) + elif isinstance(s, stmt.Block): + return est(s.first) + est(s.rest) + elif isinstance(s, stmt.Evaluate): + return est(s.value) + elif isinstance(s, stmt.For): + body_est = est(s.body) + body_est.iterations = max(1, body_est.iterations) + return body_est.times(_eval(s.extent)) + elif isinstance(s, stmt.IfThenElse): + return est(s.condition) + est(s.then_case) + est(s.else_case) + elif isinstance(s, stmt.LetStmt): + return est(s.value) + est(s.body) + elif isinstance(s, (stmt.ProducerConsumer, stmt.AttrStmt)): + return est(s.body) + elif isinstance(s, stmt.Provide): + return est(s.value) + elif isinstance(s, stmt.Realize): + return est(s.condition) + est(s.body) + elif isinstance(s, stmt.Store): + return est(s.value) + est(s.index) + est(s.predicate) + elif isinstance(s, (expr.Mul, expr.Div, expr.Mod)): + return est(s.a) + est(s.b) + PerformanceEstimate(multiplications=1) + elif isinstance(s, (expr.BinaryOpExpr, expr.CmpExpr, expr.LogicalExpr)): + if not hasattr(s, 'b'): + return est(s.a) + return est(s.a) + est(s.b) + elif isinstance(s, expr.Call): + res = PerformanceEstimate() + for a in s.args: + res += est(a) + if s.call_type == expr.Call.Halide: + # The estimate is added to processed_ops, we don't need the result here + est(s.func) + elif s.name == "tvm_if_then_else": + pass + else: + # expr.If it is a non-halide call (e.g. exp or log), consider it a mul + res += PerformanceEstimate(multiplications=1) + return res + elif isinstance(s, expr.Cast): + return est(s.value) + elif isinstance(s, expr.Load): + return est(s.index) + est(s.predicate) + elif isinstance(s, expr.Select): + return est(s.condition) + est(s.true_value) + est(s.false_value) + elif isinstance(s, expr.Reduce): + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) + res = PerformanceEstimate() + for id_elem in s.combiner.identity_element: + res += est(id_elem) + on_each_iter = est(s.condition) + for src in s.source: + on_each_iter += est(src) + for comb_res in s.combiner.result: + on_each_iter += est(comb_res) + on_each_iter.iterations = max(1, on_each_iter.iterations) + return res + on_each_iter.times(iterations) + elif isinstance(s, tvm.tensor.Tensor): + return est(s.op) + elif isinstance(s, tvm.tensor.ComputeOp): + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) + if s.reduce_axis: + res = est(s.body[0]) + else: + res = PerformanceEstimate() + for b in s.body: + res += est(b) + res.iterations = max(1, res.iterations) + res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body)) + processed_ops[s] = res + return PerformanceEstimate() + + raise ValueError("Don't know how to estimate performance of {} of type {}" + .format(s, type(s))) diff --git a/src/op/op_util.cc b/src/op/op_util.cc index b18552d5c562..4231f336a01b 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -245,5 +245,48 @@ ir::ForType IterVarTypeToForType(IterVarType iter_type) { } } +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name, const std::string& tag, + const Map& attrs) { + Array new_bodies; + int new_value_index = 0; + + // If this is a reduction then we have to clone its body + if (const Reduce* red = expr.as()) { + new_value_index = red->value_index; + + for (size_t i = 0; i < red->source.size(); ++i) { + Expr ith_red = Reduce::make(red->combiner, red->source, red->axis, red->condition, i); + new_bodies.push_back(ith_red); + } + } else { + new_value_index = 0; + new_bodies.push_back(expr); + } + + return ComputeOpNode::make(name, tag, attrs, axis, new_bodies).output(new_value_index); +} + +Tensor TransformBody(const Tensor& tensor, + std::function&)> func) { + if (const ComputeOpNode* op = tensor->op.as()) { + // Transform only one body + Expr new_body = func(op->body[tensor->value_index], op->axis); + + // If the body didn't change then we can return the same tensor + if (new_body.same_as(op->body[tensor->value_index])) { + return tensor; + } + + return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs); + } else { + return tensor; + } +} + +Tensor TransformBody(const Tensor& tensor, std::function func) { + return TransformBody(tensor, [func](const Expr& e, const Array&) { return func(e); }); +} + } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index de2e44c2ed59..da7987f7162f 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "../pass/ir_util.h" #include "../pass/arg_binder.h" #include "../schedule/message_passing.h" @@ -84,6 +85,45 @@ IterVarType ForTypeToIterVarType(ir::ForType for_type); */ ir::ForType IterVarTypeToForType(IterVarType iter_type); +/*! + * \brief Create a tensor from an expression. The expression may be a reduction, in which + * case its body will be correctly duplicated if it is a multi-valued reduction. + * + * \param expr The expr which will be the tensor's body. + * \param axis The input variables with ranges. + * \param name The tensor's name. + * \param tag The tensor's tag. + * \param attrs The tensor's attrs. + * \return A tensor. + */ +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name = "tensor", const std::string& tag = "", + const Map& attrs = {}); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function working on expressions and additionally taking + * the array of the tensor's itervars. + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, + std::function&)> func); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function (working on expressions). + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, std::function func); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc new file mode 100644 index 000000000000..56e4006c824a --- /dev/null +++ b/src/pass/zero_elimination.cc @@ -0,0 +1,1719 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.cc + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#include "zero_elimination.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "arithmetic/ModulusRemainder.h" +#include "../op/op_util.h" + +namespace tvm { +namespace ir { + +using HalideIR::Internal::gcd; +using HalideIR::Internal::lcm; + +struct ExprLess { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) < 0; + } +}; + +struct ExprEq { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) == 0; + } +}; + +// Merge two maps, prefer the right one on conflict +template +Map Merge(Map original, const Map& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template +Array Concat(Array a, const Array& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template +Expr All(const container& c) { + Expr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +// Create a select statement of the form cond ? on_true : 0 +Expr SelectElseZero(const Expr& cond, const Expr& on_true) { + return Select::make(cond, on_true, make_zero(on_true.type())); +} + +// Simplify the expression as thoroughly as possible by using all available simplifiers. +Expr SuperSimplify(Expr e, const Map& vranges = Map()) { + // For some reason no simplifier can detect that there is only one value of the variable + std::unordered_map vmap; + for (const auto& var_range : vranges) { + if (is_const_int(var_range.second->extent, 1)) { + vmap[var_range.first.get()] = var_range.second->min; + } + } + if (!vmap.empty()) { + e = Substitute(e, vmap); + } + + return CanonicalSimplify(Simplify(CanonicalSimplify(e, vranges), vranges), vranges); +} + +// Provability check that uses SuperSimplify +bool CanProve(Expr e, const Map& vranges = Map()) { + return is_one(SuperSimplify(e, vranges)); +} + +class ExprFreeVarsVisitor : public IRVisitor { + public: + std::vector free_array; + std::unordered_set bound; + std::unordered_set free; + + virtual void Visit(const NodeRef& node) { + if (const Variable* v = node.as()) { + if (!bound.count(v) && !free.count(v)) { + free.insert(v); + free_array.push_back(Var(node.node_)); + } + } else { + IRVisitor::Visit(node); + } + } + + void Visit_(const Variable* op) { + CHECK(false) << "This case shouldn't happen"; + } + + void Visit_(const LetStmt* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const For* op) { + bound.insert(op->loop_var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Let* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Reduce* op) { + for (const auto& iv : op->axis) { + bound.insert(iv->var.get()); + } + IRVisitor::Visit_(op); + } + + void Visit_(const Store* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Allocate* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Free* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Load* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } +}; + +// Get free variables of an expression +Array ExprFreeVars(const Expr& expr) { + ExprFreeVarsVisitor visitor; + visitor.Visit(expr); + return visitor.free_array; +} + +// Clone iter vars and return both the new vars and the substitution from old to new. +std::pair, std::unordered_map> CloneIterVars( + const Array& vars) { + Array new_vars; + std::unordered_map vmap; + for (const IterVar& iv : vars) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_vars.push_back(new_v); + vmap[iv->var.get()] = new_v; + } + return std::make_pair(std::move(new_vars), std::move(vmap)); +} + +// Clone reduction by cloning the axis variables. +Expr CloneReduction(const Expr& expr) { + if (const Reduce* red = expr.as()) { + Array new_axis; + std::unordered_map vmap; + std::tie(new_axis, vmap) = CloneIterVars(red->axis); + + Array src_with_newaxis; + for (const auto& src : red->source) { + src_with_newaxis.push_back(Substitute(src, vmap)); + } + + return Reduce::make(red->combiner, src_with_newaxis, + new_axis, Substitute(red->condition, vmap), red->value_index); + } else { + return expr; + } +} + +// Convert an array of itervars to an array of inequalities +Array IterVarsToInequalities(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(GE::make(v->var, v->dom->min)); + res.push_back(LT::make(v->var, v->dom->min + v->dom->extent)); + } + return res; +} + +// Convert an array of itervars to a map from vars to ranges +Map IterVarsToMap(const Array& itervars) { + Map res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Convert an array of itervars to an array of vars +Array IterVarsToVars(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array IterVarsFromMap(const Array& vars, const Map& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v + << " was not provided in map " << vranges; + res.push_back(IterVarNode::make(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner) { + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(SuperSimplify(combiner->identity_element[0]), 0)) { + return false; + } + + return is_const_value(SuperSimplify(combiner->result[0] - + (combiner->lhs[0] + combiner->rhs[0])), + 0); +} + +// Return true if zero may be factored out of a reduction with this combiner. +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index) { + if (!is_const_value(combiner->identity_element[value_index], 0)) { + return false; + } + + Expr zero = make_zero(combiner->result[value_index].type()); + Expr in = Substitute(combiner->result[value_index], + {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = SuperSimplify(in); + + return is_const_value(in, 0); +} + +Expr InlineThisCall(const Expr& expr) { + if (const Call* op = expr.as()) { + if (op->call_type == Call::CallType::Halide) { + if (const ComputeOpNode* op_comp = op->func.as()) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(expr), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return CloneReduction(ev->value); + } + } + } + } + + return expr; +} + +Tensor InlineTailCall(const Tensor& tensor) { + return op::TransformBody(tensor, InlineThisCall); +} + +class InlineTensorsMutator : public IRMutator { + public: + explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) + : inline_reductions_(inline_reductions) { + for (const Tensor& tensor : inlineable) { + inlineable_.emplace(tensor->op.operator->(), tensor->value_index); + } + } + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + const ComputeOpNode* op_comp = op->func.as(); + if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + if (op_comp && (inline_reductions_ || !op_comp->body[0].as())) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(e), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return Mutate(ev->value); + } + } + } + } + + return e; + } + + private: + std::set> inlineable_; + bool inline_reductions_; +}; + +Expr InlineTensors(const Expr& expr, const Array& inlineable, + bool inline_reductions) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(expr); +} + +Tensor InlineTensors(const Tensor& tensor, const Array& inlineable, + bool inline_reductions) { + auto transformation = + [inlineable, inline_reductions](const Expr& e) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(e); }; + return op::TransformBody(tensor, transformation); +} + + +struct NonzeronessConditionResult { + Expr cond; + Expr value; + + Expr to_expr() const { + return SelectElseZero(cond, value); + } +}; + +class NonzeronessConditionFunctor + : public ExprFunctor { + public: + NonzeronessConditionResult NonzeronessCondition(const Expr& e) { + return VisitExpr(e, e); + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const IntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const UIntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const FloatImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const StringImm*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const Add* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Sub* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Mul* op, const Expr& e) final { return BinOpMulLike_(op, e); } + result_type VisitExpr_(const Div* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const EQ* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const NE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const Not* op, const Expr& e) final { return Bool_(op, e); } + + result_type VisitExpr_(const Cast* op, const Expr& e) final { + if (op->value.type().is_bool()) { + return {op->value, make_const(e.type(), 1)}; + } else { + auto nz_a = NonzeronessCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, Cast::make(op->type, nz_a.value)}; + } + } + } + + result_type VisitExpr_(const Select* op, const Expr& e) final { + return SelectLike_(e, op->condition, op->true_value, op->false_value, Select::make); + } + + result_type VisitExpr_(const Call* op, const Expr& e) final { + if (op->name == intrinsic::tvm_if_then_else) { + return SelectLike_(e, op->args[0], op->args[1], op->args[2], if_then_else); + } else { + return Default_(e); + } + } + + NonzeronessConditionResult Default_(const Expr& e) { + return {const_true(), e}; + } + + template + NonzeronessConditionResult Const_(const TNode* op, const Expr& e) { + if (op->value == 0) { + return {const_false(), e}; + } else { + return {const_true(), e}; + } + } + + template + NonzeronessConditionResult SelectLike_(const Expr& e, const Expr& cond, const Expr& true_val, + const Expr& false_val, make_select_type make_select) { + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + if (is_const_value(nz_b.value, 0)) { + Expr new_cond = SuperSimplify(nz_a.cond && cond); + return {new_cond, nz_a.value}; + } + + if (is_const_value(nz_a.value, 0)) { + Expr new_cond = SuperSimplify(nz_b.cond && !cond); + return {new_cond, nz_b.value}; + } + + Expr new_cond = + SuperSimplify(Or::make(cond && nz_a.cond, + !cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, make_select(cond, nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpAddLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + if (Equal(nz_a.cond, nz_b.cond)) { + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; + } + } else { + Expr new_cond = SuperSimplify(Or::make(nz_a.cond, nz_b.cond)); + Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + Expr new_expr = TNode::make(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template + NonzeronessConditionResult BinOpMulLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + Expr new_cond = SuperSimplify(nz_a.cond && nz_b.cond); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, e}; + } else { + return {new_cond, TNode::make(nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpDivLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, op->b)}; + } + } + + template + NonzeronessConditionResult Bool_(const TNode* op, const Expr& e) { + return {e, make_const(e.type(), 1)}; + } +}; + +NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { + return NonzeronessConditionFunctor().NonzeronessCondition(expr); +} + +Expr LiftNonzeronessCondition(const Expr& expr) { + return NonzeronessCondition(expr).to_expr(); +} + + +class NormalizeComparisonsMutator : public IRMutator { + public: + virtual Expr Mutate_(const EQ* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return Make(op->b, op->a); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return Make(op->b, op->a); } + + private: + template + Expr Make(const Expr& a, const Expr& b) { + // rewrite LT to LE for ints + if (std::is_same::value && (a.type().is_int() || a.type().is_uint())) { + return LE::make(SuperSimplify(a - b + 1), make_zero(a.type())); + } + return TNode::make(SuperSimplify(a - b), make_zero(a.type())); + } +}; + +// Rewrite every comparison into the form a == 0, a != 0, a <= 0, and sometimes for floats a < 0 +Expr NormalizeComparisons(const Expr& expr) { + return NormalizeComparisonsMutator().Mutate(expr); +} + + +struct FactorOutAtomicFormulasResult { + std::vector atomic_formulas; + Expr rest; + + Expr to_expr() const { + Expr res = rest; + for (const Expr& e : atomic_formulas) { + res = And::make(e, res); + } + return res; + } +}; + +class FactorOutAtomicFormulasFunctor + : public ExprFunctor { + public: + result_type Atomic_(const Expr& e) { + return {{e}, make_const(e.type(), 1)}; + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const Call*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const IntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const UIntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const EQ*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const NE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LT*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GT*, const Expr& e) final { return Atomic_(e); } + + result_type VisitExpr_(const And* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const Mul* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest * res_b.rest}; + } + + result_type VisitExpr_(const Or* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + std::set_intersection(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + std::vector new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res.size()); + std::set_difference(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_a), + ExprLess()); + + std::vector new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res.size()); + std::set_difference(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_b), + ExprLess()); + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + Expr new_rest = Or::make(res_a.to_expr(), res_b.to_expr()); + + return {res, new_rest}; + } +}; + +// Transform the given formula into an array of atomic formulas and a non-atomic residual. +FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { + return FactorOutAtomicFormulasFunctor().VisitExpr(e, e); +} + + +struct EliminateDivModResult { + Expr expr; + Map substitution; + Array new_variables; + Array conditions; + Map ranges; +}; + +class EliminateDivModMutator : public IRMutator { + public: + Map substitution; + Array new_variables; + Array conditions; + Map ranges; + + explicit EliminateDivModMutator(Map ranges) + : ranges(ranges) {} + + virtual Expr Mutate_(const Div* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().first; + } else { + return Div::make(mutated_a, Mutate(op->b)); + } + } + + return Div::make(Mutate(op->a), Mutate(op->b)); + } + + virtual Expr Mutate_(const Mod* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().second; + } else { + return Mod::make(mutated_a, Mutate(op->b)); + } + } + + return Mod::make(Mutate(op->a), Mutate(op->b)); + } + + private: + dmlc::optional> AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + using tresult = dmlc::optional>; + + Expr val_e = make_const(e.type(), val); + idx_ += 1; + + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); + Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate div or mod of expr " << e + << " because its bounds cannot be inferred"; + return tresult(); + } + + auto div = Var("div" + std::to_string(idx_), e.type()); + auto mod = Var("mod" + std::to_string(idx_), e.type()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + substitution.Set(div, mut / val_e); + substitution.Set(mod, mut % val_e); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + conditions.push_back(mut == div*val_e + mod); + + if (!CanProve(mod_range->extent <= val_e)) { + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod of expr " << e + << " (probably it may change its sign)"; + // We cannot prove that mod is unique, so add additional condition + conditions.push_back(Select::make(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[{e.get(), val}] = p; + return tresult(p); + } + + int idx_{0}; + std::map, std::pair> + expr_to_vars_; +}; + +// replace every subexpr of the form e/const and e % const with a new variable +EliminateDivModResult EliminateDivMod(const Expr& expr, Map ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator.Mutate(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +// run EliminateDivMod from the condition of a reduction +Expr EliminateDivModFromReductionCondition(const Expr& expr, + Map vranges = Map()) { + if (const Reduce* red = expr.as()) { + for (const IterVar& iv : red->axis) { + vranges.Set(iv->var, iv->dom); + } + + auto elim_res = EliminateDivMod(red->condition, vranges); + + vranges = elim_res.ranges; + + Array new_axis = + Concat(red->axis, IterVarsFromMap(elim_res.new_variables, vranges, kCommReduce)); + + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + return Reduce::make(red->combiner, red->source, new_axis, new_cond, red->value_index); + } else { + return expr; + } +} + + +VarBounds VarBounds::substitute(const Map& subst) const { + auto apply_fun = [&subst](const Expr& e) { return Substitute(e, subst); }; + return {Substitute(coef, subst), + UpdateArray(lower, apply_fun), + UpdateArray(equal, apply_fun), + UpdateArray(upper, apply_fun)}; +} + +Array SolveSystemOfInequalitiesResult::as_conditions() const { + Array res; + for (const Var& v : variables) { + auto it = bounds.find(v.get()); + CHECK(it != bounds.end()); + const VarBounds& bnds = it->second; + Expr lhs = bnds.coef * v; + for (const Expr& rhs : bnds.equal) { + res.push_back(EQ::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.lower) { + res.push_back(GE::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.upper) { + res.push_back(LE::make(lhs, rhs)); + } + } + for (const Expr& e : other_conditions) { + res.push_back(e); + } + return res; +} + +// Rewrite the system of inequalities using Fourier-Motzkin elimination +// Note that variable ranges help a lot, so this parameter is even non-optional +SolveSystemOfInequalitiesResult SolveSystemOfInequalities(const Array& inequalities, + const Array& variables, + const Map& vranges) { + SolveSystemOfInequalitiesResult res; + res.variables = variables; + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current` and classify them according to polarity wrt v + // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity + // - Put the resulting combinations into `new_current` along with unclassifiable formulas + // - Replace `current` with `new_current` and move to the next variable + + // current and new_current are sorted to enable some heuristics + std::set current; + std::set new_current; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // A helper that adds an inequality to new_current if it's not obviously redundant + auto add_to_new_current = [&new_current, &vranges] (const Expr& new_ineq) { + if (CanProve(new_ineq, vranges)) { + // redundant: follows from the vranges + return; + } + if (const LE* new_le = new_ineq.as()) { + // A heuristic: check if the new inequality is a consequence of one + // of its future neighbors (in this case don't add it) or if a future neighbor is + // a consequence of the new ineq (in which case remove the neighbor) + auto it_neighbor = new_current.lower_bound(new_ineq); + if (it_neighbor != new_current.begin()) { + const LE* le = std::prev(it_neighbor)->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + new_current.erase(std::prev(it_neighbor)); + } + } + // Check the other neighbor + if (it_neighbor != new_current.end()) { + const LE* le = it_neighbor->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + it_neighbor = new_current.erase(it_neighbor); + } + } + + new_current.insert(it_neighbor, new_ineq); + } else { + new_current.insert(new_ineq); + } + }; + + // Simplify each inequality into the form `expr <= 0` and add to new_current formulas + for (const Expr& ineq : inequalities) { + add_to_new_current(NormalizeComparisons(SuperSimplify(ineq, vranges))); + } + + std::swap(current, new_current); + + for (const Var& v : variables) { + CHECK(!res.bounds.count(v.get())) << + "Variable " << v << " appears several times in the `variables` which might be a bug"; + + new_current.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (vranges.count(v)) { + const Range& range = vranges[v]; + Expr range_lbound = SuperSimplify(range->min, vranges); + Expr range_ubound = SuperSimplify(range->min + range->extent - 1, vranges); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + // Take formulas from `current` and classify them according to polarity wrt v + for (const Expr& ineq : current) { + if (const LE* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + coef_pos.push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQ* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + // Equalities may be considered as pairs of two inequalities + coef_pos.push_back({coef0, coef[1]}); + coef_neg.push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos.push_back({-coef0, -coef[1]}); + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest.push_back(ineq); + } + + // Combine each positive inequality with each negative one (by adding them together) + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = gcd(pos.first, -neg.first); + Expr c_pos = make_const(v.type(), neg.first/first_gcd); + Expr c_neg = make_const(v.type(), pos.first/first_gcd); + Expr new_lhs = c_neg*neg.second - c_pos*pos.second; + Expr new_ineq = LE::make(new_lhs, make_zero(pos.second.type())); + new_ineq = NormalizeComparisons(SuperSimplify(new_ineq, vranges)); + add_to_new_current(new_ineq); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = lcm(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = lcm(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds stored in sorted vectors + std::vector upper_bounds; + std::vector lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + Expr bound = make_const(v.type(), -coef_lcm/pos.first)*pos.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + upper_bounds.erase( + std::remove_if(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); }), + upper_bounds.end()); + // Add + upper_bounds.push_back(bound); + } + for (const auto& neg : coef_neg) { + Expr bound = make_const(v.type(), -coef_lcm/neg.first)*neg.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + lower_bounds.erase( + std::remove_if(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); }), + lower_bounds.end()); + // Add + lower_bounds.push_back(bound); + } + + // Sort the vectors and remove duplicates + for (std::vector* bounds : {&upper_bounds, &lower_bounds}) { + std::sort(bounds->begin(), bounds->end(), ExprLess()); + bounds->erase(std::unique(bounds->begin(), bounds->end(), ExprEq()), bounds->end()); + } + + // Bounds which are both lower and upper should go to equal... + std::vector equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + std::set_intersection(upper_bounds.begin(), upper_bounds.end(), + lower_bounds.begin(), lower_bounds.end(), + std::back_inserter(equal), ExprLess()); + + // ...and be removed from upper bounds... + std::vector new_upper; + new_upper.reserve(upper_bounds.size() - equal.size()); + std::set_difference(upper_bounds.begin(), upper_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_upper), ExprLess()); + + // ...and from lower bounds. + std::vector new_lower; + new_lower.reserve(lower_bounds.size() - equal.size()); + std::set_difference(lower_bounds.begin(), lower_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_lower), ExprLess()); + + // Write it to the result. + auto& bnds = res.bounds[v.get()]; + bnds.coef = make_const(v.type(), coef_lcm); + bnds.equal = equal; + bnds.lower = new_lower; + bnds.upper = new_upper; + + std::swap(current, new_current); + } + + // Everything that is left goes to res.other_conditions + for (const Expr& e : current) { + Expr e_simp = SuperSimplify(e, vranges); + if (is_const_int(e_simp, 0)) { + // contradiction detected + res.other_conditions = {const_false()}; + return res; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + res.other_conditions.push_back(e_simp); + } + } + + for (const Expr& e : rest) + res.other_conditions.push_back(e); + + return res; +} + + +// Simplify an iteration domain. +DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod) { + if (eliminate_div_mod) { + auto elim_res = EliminateDivMod(cond, vranges); + + Map new_vranges = elim_res.ranges; + Array new_axis = Concat(axis, elim_res.new_variables); + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + auto res = SimplifyDomain(new_cond, new_axis, new_vranges, false); + + Map new_old_to_new; + for (const Var& v : axis) { + new_old_to_new.Set(v, res.old_to_new[v]); + } + + Map new_new_to_old; + for (const auto& pair : res.new_to_old) { + new_new_to_old.Set(pair.first, Substitute(pair.second, elim_res.substitution)); + } + + res.old_to_new = std::move(new_old_to_new); + res.new_to_old = std::move(new_new_to_old); + + return res; + } + + auto factoratomic_res = FactorOutAtomicFormulas(cond); + std::vector& atomic_formulas = factoratomic_res.atomic_formulas; + Expr rest_of_cond = factoratomic_res.rest; + + // Put rest_of_cond into the vector of atomic formulas so that we don't forget about it. + // Although rest_of_cond is not atomic, the subsequent functions won't complain about it. + atomic_formulas.push_back(rest_of_cond); + + // vars are variables from axis followed by all the other variables from vranges + Array vars = axis; + for (const auto& pair : vranges) { + bool already = false; + for (const Var& v : vars) { + already = already || v.same_as(pair.first); + } + if (!already) { + vars.push_back(pair.first); + } + } + + auto solved_system = SolveSystemOfInequalities(atomic_formulas, vars, vranges); + + DomainSimplificationResult res; + std::unordered_map new_var_intsets; + + // Initialize new_var_intsets with the old var intsets + for (const auto& pair : vranges) { + new_var_intsets[pair.first.get()] = IntSet::range(pair.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = axis.rbegin(); it != axis.rend(); ++it) { + const Var& var = *it; + auto& bnd = solved_system.bounds[var.get()]; + // Note that we replace old vars with new ones + bnd = bnd.substitute(res.old_to_new); + if (is_one(bnd.coef) && !bnd.equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + res.old_to_new.Set(var, bnd.equal[0]); + } else { + Array lowers = Concat(bnd.equal, bnd.lower); + Array uppers = Concat(bnd.equal, bnd.upper); + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v*coef, not for v (because we don't want complex expressions + // involving division). + + // The lower bound of the best pair so far + Expr best_lower = vranges[var]->min * bnd.coef; + // The difference between the upper and the lower of the best pair so far + Expr best_diff = (vranges[var]->extent - 1) * bnd.coef; + // The overapproximation of the best difference + Expr best_diff_over = best_diff; + + for (const Expr& low : lowers) { + for (const Expr& upp : uppers) { + Expr diff = SuperSimplify(upp - low, vranges); + // Since diff may depend on some other variables, we compute its overapproximation + Expr diff_over = EvalSet(diff, new_var_intsets).max(); + + if (diff_over.same_as(HalideIR::Internal::Interval::pos_inf)) { + continue; + } + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (CanProve(diff_over - best_diff_over < 0, vranges)) { + best_lower = low; + best_diff = diff; + best_diff_over = diff_over; + } + } + } + + if (is_const_int(best_diff, 0)) { + // In this case coef*iv = best_lower + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, SuperSimplify(best_lower / bnd.coef, vranges)); + // To assure correctness, we have to add a condition that best_lower can be divided by coef + res.conditions.push_back(SuperSimplify(best_lower % bnd.coef == 0, vranges)); + } else { + std::string suffix = Equal(best_lower, vranges[var]->min * bnd.coef) ? "" : ".shifted"; + Var new_var = var.copy_with_suffix(suffix); + + // We will replace our iv with new_var + shift. + // We use rounding-up division to compute shift. Since we want to use a single formula + // without selects in as many cases as possible, we try to prove conditions manually. + Expr shift; + if (CanProve(best_lower <= 0, vranges)) { + shift = best_lower / bnd.coef; + } else if (CanProve(best_lower > -bnd.coef, vranges)) { + shift = (best_lower + bnd.coef - 1)/bnd.coef; + } else { + shift = Select::make(best_lower <= -bnd.coef, + best_lower / bnd.coef, + (best_lower + bnd.coef - 1)/bnd.coef); + } + shift = SuperSimplify(shift, vranges); + + Expr diff = SuperSimplify(best_diff_over / bnd.coef, vranges); + + if (is_const_int(diff, 0)) { + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, shift); + } else { + res.old_to_new.Set(var, new_var + shift); + // Note that we are substituting old with new, so best_lower contains new var, + // that is we have to substitute new with old in best_lower here + res.new_to_old.Set(new_var, + SuperSimplify(var - Substitute(shift, res.new_to_old), vranges)); + + new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.type()), diff); + + // Add the new var to the resulting axis + auto range = Range(make_zero(new_var.type()), SuperSimplify(diff + 1, vranges)); + res.axis.push_back(new_var); + res.ranges.Set(new_var, range); + vranges.Set(new_var, range); + } + } + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + for (const Expr& old_cond : solved_system.as_conditions()) { + res.conditions.push_back(SuperSimplify(Substitute(old_cond, res.old_to_new), vranges)); + } + + return res; +} + +// Use the condition of a reduction op to simplify its domain (axis) +Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges) { + if (const Reduce* red = expr.as()) { + Map vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); + auto res = SimplifyDomain(red->condition, IterVarsToVars(red->axis), + Merge(outer_vranges, IterVarsToMap(red->axis))); + + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(Substitute(src, res.old_to_new)); + } + + Array new_axis = IterVarsFromMap(res.axis, res.ranges, kCommReduce); + + // Perform simplification mainly to remove a possibly empty reduction. + return Simplify(Reduce::make(red->combiner, new_source, new_axis, + All(res.conditions), red->value_index)); + } else { + return expr; + } +} + +// Extract the given expr under the given condition as a separate tensor if the volume of the +// extracted tensor will be less than the volume of the outer_axis +Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, + const Array& outer_axis, + const Map& vranges) { + // TODO(sgrechanik-h): We don't use divmod elimination here because of some performance problems + auto res = SimplifyDomain(cond, outer_axis, vranges, false); + + Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), vranges); + + // Keep only those variables of the new axis which are used in the new_expr + { + Array used_res_axis; + for (const Var& var : res.axis) { + if (ExprUseVar(new_expr, var)) { + used_res_axis.push_back(var); + } + } + + res.axis = std::move(used_res_axis); + } + + // Use the new axis to simplify the new expr, removing redundant inequalities + new_expr = SuperSimplify(new_expr, res.ranges); + + // If the expression does not use vars then it is probably better to keep it inlined + if (res.axis.empty()) { + return new_expr; + } + + // Compute volumes before and after + Expr old_volume = make_const(Int(64), 1); + for (const Var& var : outer_axis) { + old_volume = old_volume * vranges[var]->extent; + } + + Expr new_volume = make_const(Int(64), 1); + for (const Var& var : res.axis) { + new_volume = new_volume * res.ranges[var]->extent; + } + + // if we can prove that the old volume is not greater than the new volume then + // prefer the old expression. + if (CanProve(old_volume <= new_volume, vranges)) { + return e; + } + + Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges), + "extracted_tensor"); + + Array args; + for (const Var& var : res.axis) { + args.push_back(res.new_to_old[var]); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); +} + + +class RemoveRedundantInequalitiesMutator : public IRMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const Expr& cond : known) { + known_.push_back(SuperSimplify(cond)); + } + } + + virtual Expr Mutate_(const Select* op, const Expr& e) { + bool has_side_effect = HasSideEffect(e); + Expr new_cond = SuperSimplify(Mutate(op->condition)); + if (is_one(new_cond) && !has_side_effect) { + return Mutate(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return Mutate(op->false_value); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); + } + } + + virtual Expr Mutate_(const Call* op, const Expr& e) { + if (op->name == intrinsic::tvm_if_then_else) { + Expr new_cond = SuperSimplify(Mutate(op->args[0])); + if (is_one(new_cond)) { + return Mutate(op->args[1]); + } else if (is_zero(new_cond)) { + return Mutate(op->args[2]); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); + } + } else { + return IRMutator::Mutate_(op, e); + } + } + + virtual Expr Mutate_(const Reduce* op, const Expr& e) { + Array known_with_axes = known_; + for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + Expr new_cond = mutator_with_axes.Mutate(op->condition); + + Array new_known = known_with_axes; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } + + virtual Expr Mutate_(const And* op, const Expr& e) { + return Mutate(op->a) && Mutate(op->b); + } + + private: + Expr MutateAtomic_(const Expr& e) { + Expr simplified = SuperSimplify(e); + for (const Expr& other : known_) { + if (Equal(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; +}; + +// Propagate information from conditions and remove redundant inequalities +// TODO(sgrechanik-h): This should be merged into standard simplifiers +Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known).Mutate(expr); +} + +// Extract from cond an implication of cond not containing vars +std::pair ImplicationNotContainingVars( + const Expr& cond, const std::unordered_set& vars) { + CHECK(cond.type().is_bool()) << "The type of cond must be bool"; + // TODO(sgrechanik-h): not + if (const And* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {pair_a.first && pair_b.first, + pair_a.second && pair_b.second}; + } else if (const Or* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {Or::make(pair_a.first, pair_b.first), cond}; + } else if (!ExprUseVar(cond, vars)) { + return {cond, const_true()}; + } else { + return {const_true(), cond}; + } +} + +// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out +// (in)equalities which do not depend on the reduction variables. +std::pair LiftConditionsThroughReduction(const Expr& cond, + const Array& red_axis, + const Array& outer_axis) { + // Factor out atomics so that we can consider this as a system of inequalities + auto factoratomic_res = FactorOutAtomicFormulas(cond); + Array atomics = factoratomic_res.atomic_formulas; + const Expr& rest = factoratomic_res.rest; + + Array allvars; + for (const IterVar& v : red_axis) { + allvars.push_back(v->var); + } + for (const IterVar& v : outer_axis) { + allvars.push_back(v->var); + } + + auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis)); + // start from reduction vars, so that input vars don't depend on them + atomics = SolveSystemOfInequalities(atomics, allvars, vranges).as_conditions(); + + // Append the rest part + Expr rewritten_cond = All(atomics) && rest; + + std::unordered_set vset; + for (const IterVar& v : red_axis) { + vset.insert(v->var.get()); + } + + // The outer (first) condition does not contain reduction vars, + // the inner (second) condition is everything else + return ImplicationNotContainingVars(rewritten_cond, vset); +} + +class ExtractReductionsMutator : public IRMutator { + public: + explicit ExtractReductionsMutator(const Array& outer_axis, + Map vranges, + std::string name = "extracted_reduction") + : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} + + Expr Mutate_(const Reduce* op, const Expr& e) { + ExtractReductionsMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_), + Merge(vranges_, IterVarsToMap(op->axis)), + name_); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + Expr new_reduce = + Reduce::make(op->combiner, new_source, op->axis, op->condition, op->value_index); + + ExprFreeVarsVisitor fv_visitor; + fv_visitor.Visit(new_reduce); + + // Vars of the tensor we are going to create for this reduction + Array vars; + for (const Var& v : outer_axis_) { + // We take variables from the outer_axis_ which are also present in the new reduction + if (fv_visitor.free.count(v.get())) { + vars.push_back(v); + } + } + + auto newaxis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); + Array new_axis = newaxis_vmap_pair.first; + new_reduce = SuperSimplify(Substitute(new_reduce, newaxis_vmap_pair.second), + IterVarsToMap(new_axis)); + + Tensor tensor = op::TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_); + + Array args; + for (const Var& v : vars) { + args.push_back(v); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); + } + + private: + Array outer_axis_; + Map vranges_; + std::string name_; + std::string tag_; + Map attrs_; +}; + +// Extract reductions as separate tensors. +Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { + return ExtractReductionsMutator(outer_axis, vranges).Mutate(expr); +} + +Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { + if (const Reduce* red = expr.as()) { + Array new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis); + Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(ExtractReductions(src, new_outer_axis, new_vranges)); + } + Expr new_condition = ExtractReductions(red->condition, new_outer_axis, new_vranges); + + return Reduce::make(red->combiner, new_source, red->axis, + new_condition, red->value_index); + } else { + return ExtractReductions(expr, outer_axis, vranges); + } +} + +Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array& axis) { + Expr result; + + if (const Reduce* red = expr.as()) { + // TODO(sgrechanik-h): There are some other operations which behave like sum + bool is_sum = IsSumCombiner(red->combiner); + if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index)) { + Expr new_red = expr; + + // Here we simplify the reduction + { + Expr cond = red->condition; + Array source = red->source; + + // If it is a summation then we can lift nonzeroness conditions from the source + // and add them to the reduction conditions + if (is_sum) { + auto nz = NonzeronessCondition(red->source[red->value_index]); + cond = nz.cond && cond; + source.Set(0, nz.value); + } + + new_red = Reduce::make(red->combiner, source, red->axis, cond, red->value_index); + new_red = SimplifyReductionDomain(new_red, IterVarsToMap(axis)); + red = new_red.as(); + + // If the reduction disappears completely then transform the result as a non-reduction + if (!red) { + return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis); + } + } + + Expr new_outer_cond, new_reduce_cond; + Array new_source = red->source; + + // Partially lift conditions from the reduce condition + std::tie(new_outer_cond, new_reduce_cond) = + LiftConditionsThroughReduction(red->condition, red->axis, axis); + + // If it's not sum then we haven't yet lifted nonzeroness cond from the source + if (!is_sum) { + Expr outer_nz_cond, nz_cond, nz_source; + auto nz = NonzeronessCondition(red->source[red->value_index]); + // Append conditions from the reduction + nz_cond = new_reduce_cond && nz.cond; + nz_source = nz.value; + std::tie(outer_nz_cond, nz_cond) = + LiftConditionsThroughReduction(nz_cond, red->axis, axis); + new_outer_cond = new_outer_cond && outer_nz_cond; + new_source.Set(red->value_index, SelectElseZero(nz_cond, nz_source)); + } + + Expr new_reduce = Reduce::make(red->combiner, new_source, red->axis, + new_reduce_cond, red->value_index); + new_reduce = ExtractAsTensorMaybe(new_reduce, new_outer_cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(new_outer_cond, new_reduce); + } else { + return SimplifyReductionDomain(expr, IterVarsToMap(axis)); + } + } else { + auto nz = NonzeronessCondition(expr); + Expr new_expr = ExtractAsTensorMaybe(nz.value, nz.cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(nz.cond, new_expr); + } + + // Note that RemoveRedundantInequalities can sometimes propagate equalities which + // other simplifiers cannot, like (i % 3) == 0. + Array axis_conds = IterVarsToInequalities(axis); + result = RemoveRedundantInequalities(result, axis_conds); + + // Sometimes ExtractAsTensorMaybe doesn't perform extraction, so there may be some non-top + // reductions left, take care of them + Map vrange = IterVarsToMap(axis); + return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), vrange), + vrange); +} + +Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor) { + return op::TransformBody(tensor, OptimizeAndLiftNonzeronessConditionsImpl); +} + +TVM_REGISTER_API("ir_pass.IsSumCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IsSumCombiner(args[0]); + }); + +TVM_REGISTER_API("ir_pass.CanFactorZeroFromCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CanFactorZeroFromCombiner(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.LiftNonzeronessCondition") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LiftNonzeronessCondition(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTailCall") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = InlineTailCall(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTensors") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + Expr e = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(e); + } else if (args.size() == 2) { + *ret = InlineTensors(e, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(e, args[1], args[2]); + } + } else if (args[0].IsNodeType()) { + Tensor t = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(t); + } else if (args.size() == 2) { + *ret = InlineTensors(t, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(t, args[1], args[2]); + } + } + }); + +TVM_REGISTER_API("ir_pass.SolveSystemOfInequalities") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SolveSystemOfInequalities(args[0], args[1], args[2]).as_conditions(); + }); + +TVM_REGISTER_API("ir_pass.SimplifyDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + auto res = SimplifyDomain(args[0], args[1], args[2]); + Array axis = IterVarsFromMap(res.axis, res.ranges); + *ret = Array({All(res.conditions), axis, res.old_to_new, res.new_to_old}); + }); + +TVM_REGISTER_API("ir_pass.SimplifyReductionDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SimplifyReductionDomain(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.ExtractAsTensorMaybe") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractAsTensorMaybe(args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_API("ir_pass.ExtractReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractReductions(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("ir_pass.ExtractNonTopReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractNonTopReductions(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = OptimizeAndLiftNonzeronessConditions(args[0]); + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h new file mode 100644 index 000000000000..bcfba038cc02 --- /dev/null +++ b/src/pass/zero_elimination.h @@ -0,0 +1,239 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.h + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#ifndef TVM_PASS_ZERO_ELIMINATION_H_ +#define TVM_PASS_ZERO_ELIMINATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace ir { + +/*! + * \brief Clone the reduction by cloning its iteration variables. + */ +Expr CloneReduction(const Expr& expr); + +/*! + * \brief Check if the given combiner represents summation. + */ +EXPORT bool IsSumCombiner(const CommReducer& combiner); + +/*! + * \brief Check if zero may be factored out of a reduction with this combiner when it is in + * the \p value_index position. + * + * For example, if the combiner works on tuples of two elements and `value_index = 1`, + * check that `(a, 0) combine (b, 0) = (c, 0)` for any a, b and some c. + * Note that all combiners generated by autodiff have this property. + */ +EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index); + +/*! + * \brief Transform the expression into `c ? e : 0`, that is lift the condition of being + * possible to be non-zero to the top level. + */ +EXPORT Expr LiftNonzeronessCondition(const Expr& expr); + +/*! + * \brief If the body of the tensor consists of a single tensor call (indexing) expression, + * inline it. + */ +EXPORT Tensor InlineTailCall(const Tensor& tensor); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param expr The expression to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Expr InlineTensors(const Expr& expr, + const Array& inlineable = Array(), + bool inline_reductions = false); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param tensor The tensor whose body to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Tensor InlineTensors(const Tensor& tensor, + const Array& inlineable = Array(), + bool inline_reductions = false); + + +/*! + * \brief A struct representing a set of inequalities describing bounds of a variable. + * + * Given a variable x, this struct represents the following (in)equalities: + * - `coef*x >= low` for each `low` in `lower` + * - `coef*x == eq` for each `eq` in `equal` + * - `coef*x <= upp` for each `upp` in `upper` + * + * Note that every array is supposed to be sorted in the order of increasing expression + * complexity. + */ +struct VarBounds { + Expr coef; + Array lower; + Array equal; + Array upper; + + /*! + * \brief Perform substitution on all components of the struct. + */ + VarBounds substitute(const Map& subst) const; +}; + +/*! + * \brief A struct representing a system of inequalities resulted from Fourier-Motzkin elimination. + */ +struct SolveSystemOfInequalitiesResult { + Array variables; + std::unordered_map bounds; + Array other_conditions; + + /*! + * \brief Combine the information into an array of (in)equalities. + */ + Array as_conditions() const; +}; + +/*! + * \brief Rewrite the system of inequalities using Fourier-Motzkin elimination. + * + * This function takes an array of (in)equalities and an array of variables, and essentially + * rewrites the (in)equalities into an array of (in)equalities of the following form: + * + * x0 >= f0(x1, x2, ..., xn) + * x0 <= g0(x1, x2, ..., xn) + * x1 >= f1(x2, ..., xn) + * x1 <= g1(x2, ..., xn) + * ... + * xn >= fn() // just a constant + * xn <= gn() // just a constant + * + * This array is represented in a more structural way using SolveSystemOfInequalitiesResult. + * + * Note that the algorithm is extremely slow, it is super-exponential, so please provide variable + * ranges to aid the removal of redundant inequalities. + * + * \param inequalities The original (in)equalities. + * \param variables The variables x0, ..., xn + * \param vranges A map from variables to the corresponding value ranges. Extremely important for + * efficiency. + */ +EXPORT SolveSystemOfInequalitiesResult SolveSystemOfInequalities( + const Array& inequalities, const Array& variables, const Map& vranges); + +/*! + * \brief A struct representing a result of domain simplification. It is basically + * a new array of variables, the information about their ranges, and a new condition together with + * substitutions from the old variables to the new ones and from the new ones to the old ones. + */ +struct DomainSimplificationResult { + Array conditions; + Array axis; + Map ranges; + Map old_to_new; + Map new_to_old; +}; + +/*! + * \brief Simplify an iteration domain. + * + * An iteration domain is basically an array of variables and a condition. The function will do the + * following: + * - Replace div and mod operations with new variables (optional). + * - Extract (in)equalities from the condition. + * - Perform Fourier-Motzkin elimination. + * - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with `y + d` + * where `d` is a new variable such that `0 <= d <= 2`). + * - Remove redundant variables. + * - Infer new variable ranges (hopefully more precise). + * + * \param cond The condition of the original domain. + * \param axis The variables of the original domain. + * \param vranges A map from variables (both domain and outer) to their value ranges. + * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables. + */ +EXPORT DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod = true); + + +/*! + * \brief Simplify the iteration domain of a reduction expression using SimplifyDomain. + */ +EXPORT Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges); + +/*! + * \brief Extract the given expression under the given condition as a separate tensor if the volume + * of the extracted tensor will be less than the volume of the \p outer_axis. + * + * \param expr The expression to extract. + * \param cond A condition which is assumed to be true. + * \param outer_axis Some variables, usually input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return Either a call to an extracted tensor or the original expression. + */ +EXPORT Expr ExtractAsTensorMaybe(const Expr& expr, const Expr& cond, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors. This may be needed when non-top-level reductions + * are created. + * + * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors, but if the expr itself is a reduction, leave it + * intact. + * + * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Perform lifting of conditions of being possible to be non-zero together with + * applying some transformations like simplifying the reduction domain. Works only with + * this particular tensor's body, i.e. doesn't perform inlining. + */ +EXPORT Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor); + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_ZERO_ELIMINATION_H_ diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py new file mode 100644 index 000000000000..a1d4070a72f7 --- /dev/null +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -0,0 +1,464 @@ +import random +import sys +import numpy as np +import tvm +from tvm import comm_reducer +from tvm.testing import estimate_performance +from tvm.ir_pass import Simplify, Equal, LiftNonzeronessCondition, IsSumCombiner, \ + CanFactorZeroFromCombiner, InlineTailCall, InlineTensors, SolveSystemOfInequalities, \ + SimplifyDomain, SimplifyReductionDomain, ExtractAsTensorMaybe, ExtractReductions, \ + ExtractNonTopReductions, OptimizeAndLiftNonzeronessConditions + +def get_shape(tensor): + return [s.value for s in tensor.shape] + +def check_eq(t1, t2, args): + s1 = tvm.create_schedule(t1.op) + m1 = tvm.build(s1, [t1] + args) + + s2 = tvm.create_schedule(t2.op) + m2 = tvm.build(s2, [t2] + args) + + for _ in range(5): + arg_vals = [tvm.ndarray.array(np.random.uniform(-10, 10, size=get_shape(a)) + .astype(a.dtype)) + for a in [t1] + args] + m1(*arg_vals) + res1 = arg_vals[0].asnumpy() + m2(*arg_vals) + res2 = arg_vals[0].asnumpy() + + np.testing.assert_allclose(res1, res2, atol=1e-3, rtol=1e-2) + +def check_symeq(expr1, expr2): + expr1 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1)) + expr2 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr2)) + + if tvm.ir_pass.Equal(expr1, expr2): + return + + diff = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1 - expr2)) + if not Equal(diff, tvm.const(0, expr1.dtype)): + raise AssertionError("Expressions {} and {} are not equal, their diff is {}" + .format(expr1, expr2, diff)) + +def compute(shape, fcompute): + """Like tvm.compute but automatically extracts reductions.""" + return tvm.compute(shape, + lambda *vs: ExtractNonTopReductions( + fcompute(*vs), vs, {v: tvm.Range(0, s) for v, s in zip(vs, shape)})) + +def check_tensor_symeq(A, B): + if not isinstance(B, tvm.tensor.Tensor): + B = compute(A.shape, B) + vmap = {a.var: b.var for a, b in zip(A.op.axis, B.op.axis)} + expr_a = tvm.ir_pass.Substitute(A.op.body[A.value_index], vmap) + expr_b = B.op.body[B.value_index] + expr_a = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_a, [], True)) + expr_b = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_b, [], True)) + if not Equal(expr_a, expr_b): + print(expr_a) + print(expr_b) + raise AssertionError("The expressions are not equal") + +def check_eq_bruteforce(expr1, expr2, vranges): + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tvm.ir_pass.Substitute(expr1 == expr2, vmap) + + A = compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.ndarray.empty(A.shape, A.dtype)] + sch = tvm.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + res = args[0].asnumpy() + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = ", ".join([v + " = " + str(i) for v, i in sorted(counterex)]) + raise AssertionError("Expressions {}\nand {}\nare not equal on {}\n" + "Counterexample: {}" + .format(expr1, expr2, vranges, counterex)) + +prod_combiner = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) +sum_combiner = comm_reducer(lambda x, y: x + y, lambda t0: tvm.const(0, t0)) +sum2_combiner = comm_reducer(lambda x, y: y + x, lambda t0: tvm.const(0, t0)) +sum_derivative_combiner = comm_reducer(lambda x, y: (x[0] + y[0], y[1] + x[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +prod_derivative_combiner = comm_reducer(lambda x, y: (x[0]*y[0], x[0]*y[1] + x[1]*y[0]), + lambda t0, t1: (tvm.const(1, t0), tvm.const(0, t1))) +sum_both_combiner = comm_reducer(lambda x, y: (x[0] + y[0], x[0] + y[0] + x[1] + y[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +xor_combiner = comm_reducer(lambda x, y: x ^ y, lambda t0: tvm.const(0, t0)) + +def test_is_sum_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert IsSumCombiner(sum_combiner(i, k).combiner) + assert IsSumCombiner(sum_combiner(f, k).combiner) + assert IsSumCombiner(sum2_combiner(i, k).combiner) + assert IsSumCombiner(sum2_combiner(f, k).combiner) + assert not IsSumCombiner(sum_derivative_combiner((f, f), k)[0].combiner) + assert not IsSumCombiner(prod_combiner(f, k).combiner) + assert not IsSumCombiner(prod_derivative_combiner((f, f), k)[1].combiner) + +def test_can_factor_zero_from_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert CanFactorZeroFromCombiner(sum_combiner(i, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum2_combiner(f, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 1) + assert not CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 1) + assert CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 0) + assert not CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 1) + +def test_lift_nonzeroness_condition(): + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + A = tvm.placeholder((10,), name='A') + + def _check(shape, fun, A=A): + T1 = tvm.compute(shape, fun) + T2 = tvm.compute(shape, lambda *args: LiftNonzeronessCondition(fun(*args))) + check_eq(T1, T2, [A]) + assert isinstance(T2.op.body[0], tvm.expr.Select) + + _check((10,), lambda i: A[i]) + _check((10,), lambda i: A[i] + (i % 2 == 0)) + _check((10,), lambda i: A[i]*(i % 2 == 0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), 0.0, A[i]) + (i % 2 == 0)) + def e1(i): return tvm.expr.Select((i % 2 == 1), 0.0, A[i]) + def e2(i): return tvm.expr.Select((i % 2 == 0), A[(i + 1) % 10], 0.0) + def e3(i): return tvm.expr.Select((i % 2 == 1), A[i], 0.0) + _check((10,), lambda i: e1(i) + e2(i) + e3(i) + e1(i)*e2(i)) + _check((10,), lambda i: e1(i)*e3(i)) + _check((10,), lambda i: e1(i)*e2(i)) + _check((10,10), lambda i, j: A[i]*(i == j) + A[j]*(i == 2*j) + A[j]*(j == i)) + _check((10,10), lambda i, j: tvm.min(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: tvm.max(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: A[i]*(i == j) - A[j]*(i == 2*j)) + _check((10,10), lambda i, j: A[i]*(i == j) / (1 + tvm.abs(A[j]*(i == 2*j)))) + _check((10,10), lambda i, j: i*(i < j) + j*(i > j)) + _check((10,10), lambda i, j: i*(i < j) % (1 + j*(i > j))) + + def _check_symeq(expr1, expr2): + expr1 = LiftNonzeronessCondition(expr1) + expr2 = LiftNonzeronessCondition(expr2) + print(expr1) + print(expr2) + print() + check_symeq(expr1, expr2) + + _check_symeq(tvm.expr.Select(tvm.expr.EQ(k, l), 0.0, tvm.expr.Cast('float32', (k < n))), + tvm.expr.Select(tvm.expr.And((k < n), tvm.expr.NE(k, l)), 1.0, 0.0)) + _check_symeq(tvm.min(tvm.expr.Cast('int32', k < n)*l, tvm.expr.Select(k >= n, 0, 1)), + tvm.expr.Select(k < n, tvm.min(l, 1), 0)) + +def test_inline_tail_call(): + A = tvm.compute((10, 10), lambda i, j: i + j*j) + B = tvm.compute((5, 6), lambda k, l: A[k + l, k + 1]) + C = InlineTailCall(B) + resbody = lambda k, l: k + l + (k + 1)*(k + 1) + check_symeq(C.op.body[0], resbody(*[iv.var for iv in C.op.axis])) + +def test_inline_tensors(): + A = tvm.compute((10, 10), lambda i, j: i + j) + B = tvm.compute((10, 10), lambda i, j: i * j) + C = tvm.compute((10, 10), lambda i, j: A[i, j] + B[i, j]) + k = tvm.reduce_axis((0, 5), name="k") + D = tvm.compute((10, 10), lambda i, j: tvm.sum(A[i, k], k)) + E = tvm.compute((10, 10), lambda i, j: A[2, j] + C[i, 2] + D[i, j]) + + R = InlineTensors(E) + resbody = lambda i, j: 2 + j + i + 2 + i*2 + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A]) + resbody = lambda i, j: 2 + j + C[i, 2] + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A, C]) + resbody = lambda i, j: 2 + j + ((i + 2) + B[i, 2]) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [B, C]) + resbody = lambda i, j: A[2, j] + (A[i, 2] + i*2) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + +def test_solve_system_of_inequalities(): + seed = random.randrange(sys.maxsize) + print("\nseed: {}\n".format(seed)) + random.seed(seed) + + def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): + vs = [tvm.var("x" + str(i)) for i in range(variables)] + + fs = [] + for i in range(formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s2 += random.randint(coef[0], coef[1]) + op = random.choice([tvm.expr.EQ, tvm.expr.LE, tvm.expr.LT, tvm.expr.GE, tvm.expr.GT]) + fs.append(op(s1, s2)) + + vranges = {v: tvm.Range(bounds[0], bounds[1] + 1) for v in vs} + + before = tvm.all(*fs) + print(before) + after = tvm.all(*SolveSystemOfInequalities(fs, vs, vranges)) + print(after) + print() + + check_eq_bruteforce(before, after, vranges) + + for i in range(3): + _check(1, 1) + for i in range(3): + _check(1, 2) + + for i in range(3): + _check(2, 1) + for i in range(3): + _check(2, 2) + for i in range(3): + _check(2, 3) + + # Somewhere here coefficients in the results become too large, leading to overflow, + # so we use smaller initial coefficients + + for i in range(5): + _check(3, 3, coef=(-2,2)) + for i in range(5): + _check(3, 4, coef=(-2,2)) + + for i in range(5): + _check(4, 3, coef=(-1,1)) + + for i in range(5): + _check(10, 2, coef=(-1,1), bounds=(0, 4)) + for i in range(5): + _check(10, 3, coef=(0,1), bounds=(0, 4)) + +def test_simplify_domain(): + # Note that here we test both SimplifyDomain and SimplifyReductionDomain. + def _check(cond, axis, volume, vranges={}): + vranges_with_axis = dict(vranges) + vranges_with_axis.update({iv.var: iv.dom for iv in axis}) + variables = [iv.var for iv in axis] + new_cond, new_axis, old_to_new, new_to_old = SimplifyDomain(cond, variables, + vranges_with_axis) + + print("old", axis, cond) + print("new", new_axis, new_cond) + print("old_to_new", old_to_new) + print("new_to_old", new_to_old) + print() + + cond_subst = tvm.ir_pass.Substitute(cond, old_to_new) + new_vranges = vranges.copy() + new_vranges.update({v.var: v.dom for v in new_axis}) + # If new_cond is true in the new domain, then cond_subst must also be true in the new + # domain, but the reverse is not necessarily true + check_eq_bruteforce(tvm.all(new_cond, cond_subst), new_cond, new_vranges) + + new_cond_subst = tvm.ir_pass.Substitute(new_cond, new_to_old) + old_vranges = vranges.copy() + old_vranges.update({v.var: v.dom for v in axis}) + check_eq_bruteforce(cond, tvm.all(cond, new_cond_subst), old_vranges) + + # Also check SimplifyReductionDomain + reduction = xor_combiner(sum([v*(i + 1) for i, v in enumerate(axis)]), axis) + new_reduction = SimplifyReductionDomain(reduction, vranges) + check_eq_bruteforce(reduction, new_reduction, vranges) + + vol = np.prod([iv.dom.extent.value for iv in new_axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}\n" + "Old domain {} where {}\nNew domain {} where {}" + .format(vol, volume, axis, cond, new_axis, new_cond)) + + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + + _check((k <= l), [k, l, n], 125) + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + _check(tvm.expr.EQ(2*l, k), [k, l, n], 15) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations yet + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 15) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((-3, 2), name="k") + l = tvm.reduce_axis((-3, 2), name="l") + n = tvm.reduce_axis((-3, 2), name="n") + + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + # Now there are only two possible values for l: {l = -1, k = -2} and {l = 0, k = 0} + _check(tvm.expr.EQ(2*l, k), [k, l, n], 10) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 10) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((0, 6), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 30), name="n") + + _check(tvm.all(k + l*6 == n), [k, l, n], 30) + _check(tvm.all(k + l*6 == n), [n, k, l], 30) + _check(tvm.all(k + l*6 == n), [n, l, k], 30) + + _check(tvm.all(n / 5 == k, n % 5 == l), [l, k, n], 30) + # TODO: Same thing with the order + _check(tvm.all(n / 5 == k, n % 5 == l), [n, l, k], 30) + + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + # TODO: This is not fully optimized because we don't have a solver + _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 144) + +def test_extract_as_tensor_maybe(): + def _check(shape, fcompute, volume=None, vranges={}): + def fcompute_extracted(*variables): + vranges_updated = dict(vranges) + vranges_updated.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)}) + expr = fcompute(*variables) + if isinstance(expr, tvm.expr.Select): + new_true_value = ExtractAsTensorMaybe(expr.true_value, + expr.condition, + variables, + vranges_updated) + expr = tvm.expr.Select(expr.condition, + new_true_value, + expr.false_value) + if volume is not None: + assert isinstance(new_true_value, tvm.expr.Call) + vol = np.prod([iv.dom.extent.value for iv in new_true_value.func.axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}" + .format(vol, volume)) + return expr + + A = tvm.compute(shape, fcompute) + B = tvm.compute(shape, fcompute_extracted) + check_eq(A, B, []) + + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i + j, 0), volume=30) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, j, 0), volume=10) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i, 0), volume=3) + _check((10, 10), lambda i, j: tvm.expr.Select(tvm.all(i < j, j < 5), i + j, 0), volume=16) + # This one doesn't get extracted + _check((10, 10), lambda i, j: tvm.expr.Select(i <= j, i + j, 0)) + +def test_extract_reductions(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + + A = tvm.compute((10, 10), + lambda i, j: + ExtractReductions(sum_combiner(i + k + xor_combiner(j*k + l, l), k), + [i, j], + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)})) + B = tvm.compute((10, 10), lambda j, k: xor_combiner(j*k + l, l)) + C = tvm.compute((10, 10), lambda i, j: sum_combiner(i + k + B[j, k], k)) + check_eq(C, A, []) + + fcompute = lambda i, j: \ + ExtractReductions(sum_both_combiner((prod_derivative_combiner((i*n + 2*k, j + k), k)[1], + xor_combiner(j*n + l, l)), n)[1], + [i, j], + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}) + A = tvm.compute((10, 10), fcompute) + _, B = tvm.compute((10, 10, 10), + lambda i, j, n: prod_derivative_combiner((i*n + 2*k, j + k), k)) + C = tvm.compute((10, 10), lambda j, n: xor_combiner(j*n + l, l)) + _, D = tvm.compute((10, 10), lambda i, j: sum_both_combiner((B[i, j, n], C[j, n]), n)) + check_eq(A, D, []) + +def test_optimize_and_lift_nonzeroness(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + A = tvm.placeholder((10, 10), name="A") + + zero = tvm.const(0, 'float32') + + B = compute((10, 10), lambda i, j: tvm.sum((i == j)*A[i, k] + A[k, j]*(i == j), k)) + B = OptimizeAndLiftNonzeronessConditions(B) + R = lambda i, j: tvm.expr.Select(i == j, + tvm.sum(A[j, k] + A[k, j], k), + zero) + check_tensor_symeq(B, R) + + # TODO: This test is unstable: sometimes the resulting condition looks like + # (i == j)*(j == i) instead of (i == j) + # B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] + + # (i == j)*A[k, j]*(i == k), k)) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero) + # check_tensor_symeq(B, R) + + B = compute((10, 10), lambda i, j: tvm.sum((i < j)*(j < k)*A[j, k], k)) + B = OptimizeAndLiftNonzeronessConditions(B) + k1 = tvm.reduce_axis((2, 10), name="k1") + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i < j, j < 10), + tvm.sum(tvm.expr.Select(j < k1, A[j, k1], zero), k1), + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + + # TODO: This one needs the equation solver + # B = compute((10, 10), lambda i, j: tvm.sum((i <= j)*(j <= k)*A[j, k], k, where=(i >= k))) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = compute((10, 10), lambda i, j: tvm.expr.Select((i == j), A[i, i], zero)) + # check_eq(B, R, [A]) + # assert estimate_performance(B) <= estimate_performance(R) + + B = compute((10, 10), + lambda i, j: prod_derivative_combiner((A[j, k], (i <= j)*(j < k)*A[i, k]), k)[1]) + B = OptimizeAndLiftNonzeronessConditions(B) + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i <= j, j < 10), + prod_derivative_combiner((A[j, k], (j < k)*A[i, k]), k)[1], + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + +if __name__ == "__main__": + test_is_sum_combiner() + test_can_factor_zero_from_combiner() + test_lift_nonzeroness_condition() + test_inline_tail_call() + test_inline_tensors() + test_solve_system_of_inequalities() + test_simplify_domain() + test_extract_as_tensor_maybe() + test_extract_reductions() + test_optimize_and_lift_nonzeroness() From 1d7b561afa7b30204c9c25db06ca5c128afcdd2a Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 27 Feb 2019 17:39:31 +0300 Subject: [PATCH 02/10] Fix several Or-related issues --- src/pass/zero_elimination.cc | 39 ++++++++----------- .../unittest/test_pass_zero_elimination.py | 11 ++++++ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 56e4006c824a..352ca4400abc 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -367,7 +367,11 @@ class NonzeronessConditionFunctor : public ExprFunctor { public: NonzeronessConditionResult NonzeronessCondition(const Expr& e) { - return VisitExpr(e, e); + if (e.type().is_bool()) { + return {e, const_true()}; + } else { + return VisitExpr(e, e); + } } result_type VisitExpr_(const Variable*, const Expr& e) final { return Default_(e); } @@ -382,13 +386,6 @@ class NonzeronessConditionFunctor result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } - result_type VisitExpr_(const EQ* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const NE* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const LE* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const LT* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const GE* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const GT* op, const Expr& e) final { return Bool_(op, e); } - result_type VisitExpr_(const Not* op, const Expr& e) final { return Bool_(op, e); } result_type VisitExpr_(const Cast* op, const Expr& e) final { if (op->value.type().is_bool()) { @@ -445,9 +442,7 @@ class NonzeronessConditionFunctor return {new_cond, nz_b.value}; } - Expr new_cond = - SuperSimplify(Or::make(cond && nz_a.cond, - !cond && nz_b.cond)); + Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { return {new_cond, e}; } else { @@ -467,7 +462,7 @@ class NonzeronessConditionFunctor return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; } } else { - Expr new_cond = SuperSimplify(Or::make(nz_a.cond, nz_b.cond)); + Expr new_cond = SuperSimplify(nz_a.cond || nz_b.cond); Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); Expr new_expr = TNode::make(new_a, new_b); @@ -499,11 +494,6 @@ class NonzeronessConditionFunctor return {nz_a.cond, TNode::make(nz_a.value, op->b)}; } } - - template - NonzeronessConditionResult Bool_(const TNode* op, const Expr& e) { - return {e, make_const(e.type(), 1)}; - } }; NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { @@ -628,7 +618,7 @@ class FactorOutAtomicFormulasFunctor res_a.atomic_formulas = std::move(new_cond_a); res_b.atomic_formulas = std::move(new_cond_b); - Expr new_rest = Or::make(res_a.to_expr(), res_b.to_expr()); + Expr new_rest = res_a.to_expr() || res_b.to_expr(); return {res, new_rest}; } @@ -670,11 +660,11 @@ class EliminateDivModMutator : public IRMutator { if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { return var_pair_opt.value().first; } else { - return Div::make(mutated_a, Mutate(op->b)); + return mutated_a / Mutate(op->b); } } - return Div::make(Mutate(op->a), Mutate(op->b)); + return Mutate(op->a) / Mutate(op->b); } virtual Expr Mutate_(const Mod* op, const Expr& e) { @@ -689,11 +679,11 @@ class EliminateDivModMutator : public IRMutator { if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { return var_pair_opt.value().second; } else { - return Mod::make(mutated_a, Mutate(op->b)); + return mutated_a % Mutate(op->b); } } - return Mod::make(Mutate(op->a), Mutate(op->b)); + return Mutate(op->a) % Mutate(op->b); } private: @@ -1427,7 +1417,10 @@ std::pair ImplicationNotContainingVars( } else if (const Or* op = cond.as()) { auto pair_a = ImplicationNotContainingVars(op->a, vars); auto pair_b = ImplicationNotContainingVars(op->b, vars); - return {Or::make(pair_a.first, pair_b.first), cond}; + return {pair_a.first || pair_b.first, + (pair_a.first || pair_b.second) && + (pair_b.first || pair_a.second) && + (pair_a.second || pair_b.second)}; } else if (!ExprUseVar(cond, vars)) { return {cond, const_true()}; } else { diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index a1d4070a72f7..9f8c110b7197 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -451,6 +451,17 @@ def test_optimize_and_lift_nonzeroness(): check_eq(B, R, [A]) assert estimate_performance(B) <= estimate_performance(R) + B = compute((10,), lambda i: + tvm.sum(A[i, k]*tvm.any(tvm.all(i < 5, k < 6), tvm.all(i > 5, k > 4)), k)) + B = OptimizeAndLiftNonzeronessConditions(B) + R = compute((10,), lambda i: + tvm.expr.Select(tvm.any(i < 5, i > 5), + tvm.sum(A[i, k], k, where=tvm.all(tvm.any(i < 5, k > 4), + tvm.any(i > 5, k < 6))), + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + if __name__ == "__main__": test_is_sum_combiner() test_can_factor_zero_from_combiner() From cdd9c814a016de3b4837d54ac66bdf225d1fb87a Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 27 Feb 2019 18:22:01 +0300 Subject: [PATCH 03/10] Implement is_const_int via is_const_value --- include/tvm/ir_operator.h | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index 09a046228b5f..b336caac3e96 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -513,22 +513,6 @@ inline bool is_negative_const(const Expr& a) { } } -inline bool is_const_int(const Expr& x, int64_t value) { - if (const auto* op = x.as()) { - return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); - } else if (const auto* op = x.as()) { - const Expr& val = op->value; - if (const auto* opv = val.as()) { - return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); - } - } - return false; -} - template inline bool is_const_value(const Expr& e, ValueType value) { static_assert(std::is_integral::value, @@ -537,7 +521,7 @@ inline bool is_const_value(const Expr& e, ValueType value) { if (const ir::IntImm* i = e.as()) { return i->value == value; } else if (const ir::UIntImm* i = e.as()) { - return (value >= 0) && (i->value == (uint64_t)value); + return (value >= 0) && (i->value == static_cast(value)); } else if (const ir::FloatImm* i = e.as()) { return i->value == value; } else if (const ir::Cast* c = e.as()) { @@ -549,6 +533,15 @@ inline bool is_const_value(const Expr& e, ValueType value) { } } +inline bool is_const_int(const Expr& x, int64_t value) { + if (x.as() || x.as()) { + return is_const_value(x, value); + } else if (const auto* op = x.as()) { + return !op->value.as() && is_const_int(op->value, value); + } + return false; +} + inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as()) { From f020bf17e102e63c2c9ed0002b5db99cb0578279 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Thu, 28 Feb 2019 16:32:57 +0300 Subject: [PATCH 04/10] Use vranges in IsSumCombiner, CanFactorZeroFromCombiner and OptimizeAndLiftNonzeronessConditions --- src/pass/zero_elimination.cc | 68 ++++++++++++------- src/pass/zero_elimination.h | 14 +++- .../unittest/test_pass_zero_elimination.py | 32 +++++++++ 3 files changed, 87 insertions(+), 27 deletions(-) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 352ca4400abc..f44d671e67bc 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -248,23 +248,24 @@ Array IterVarsFromMap(const Array& vars, const Map& vr } // Return true if this combiner is just a sum. -bool IsSumCombiner(const CommReducer& combiner) { +bool IsSumCombiner(const CommReducer& combiner, const Map& vranges) { if (combiner->result.size() != 1) { return false; } - if (!is_const_value(SuperSimplify(combiner->identity_element[0]), 0)) { + if (!is_const_value(SuperSimplify(combiner->identity_element[0], vranges), 0)) { return false; } - return is_const_value(SuperSimplify(combiner->result[0] - - (combiner->lhs[0] + combiner->rhs[0])), - 0); + Expr should_be_zero = + SuperSimplify(combiner->result[0] - (combiner->lhs[0] + combiner->rhs[0]), vranges); + return is_const_value(should_be_zero, 0); } // Return true if zero may be factored out of a reduction with this combiner. -bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index) { - if (!is_const_value(combiner->identity_element[value_index], 0)) { +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map& vranges) { + if (!is_const_value(SuperSimplify(combiner->identity_element[value_index], vranges), 0)) { return false; } @@ -272,7 +273,7 @@ bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index) { Expr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, {combiner->rhs[value_index], zero}}); - in = SuperSimplify(in); + in = SuperSimplify(in, vranges); return is_const_value(in, 0); } @@ -1545,13 +1546,16 @@ Expr ExtractNonTopReductions(const Expr& expr, } } -Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array& axis) { +Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, + const Array& axis, + const Map& vranges) { Expr result; + Map combined_vranges = Merge(vranges, IterVarsToMap(axis)); if (const Reduce* red = expr.as()) { // TODO(sgrechanik-h): There are some other operations which behave like sum - bool is_sum = IsSumCombiner(red->combiner); - if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index)) { + bool is_sum = IsSumCombiner(red->combiner, vranges); + if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) { Expr new_red = expr; // Here we simplify the reduction @@ -1568,12 +1572,12 @@ Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Arraycombiner, source, red->axis, cond, red->value_index); - new_red = SimplifyReductionDomain(new_red, IterVarsToMap(axis)); + new_red = SimplifyReductionDomain(new_red, combined_vranges); red = new_red.as(); // If the reduction disappears completely then transform the result as a non-reduction if (!red) { - return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis); + return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis, vranges); } } @@ -1600,15 +1604,17 @@ Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Arraycombiner, new_source, red->axis, new_reduce_cond, red->value_index); new_reduce = ExtractAsTensorMaybe(new_reduce, new_outer_cond, - IterVarsToVars(axis), IterVarsToMap(axis)); + IterVarsToVars(axis), + combined_vranges); result = SelectElseZero(new_outer_cond, new_reduce); } else { - return SimplifyReductionDomain(expr, IterVarsToMap(axis)); + return SimplifyReductionDomain(expr, combined_vranges); } } else { auto nz = NonzeronessCondition(expr); Expr new_expr = ExtractAsTensorMaybe(nz.value, nz.cond, - IterVarsToVars(axis), IterVarsToMap(axis)); + IterVarsToVars(axis), + combined_vranges); result = SelectElseZero(nz.cond, new_expr); } @@ -1619,23 +1625,33 @@ Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array vrange = IterVarsToMap(axis); - return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), vrange), - vrange); + return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), combined_vranges), + combined_vranges); } -Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor) { - return op::TransformBody(tensor, OptimizeAndLiftNonzeronessConditionsImpl); +Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor, const Map& vranges) { + auto transform_func = [&vranges](const Expr& expr, const Array& axis) { + return OptimizeAndLiftNonzeronessConditionsImpl(expr, axis, vranges); + }; + return op::TransformBody(tensor, transform_func); } TVM_REGISTER_API("ir_pass.IsSumCombiner") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = IsSumCombiner(args[0]); + if (args.size() >= 2) { + *ret = IsSumCombiner(args[0], args[1]); + } else { + *ret = IsSumCombiner(args[0]); + } }); TVM_REGISTER_API("ir_pass.CanFactorZeroFromCombiner") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CanFactorZeroFromCombiner(args[0], args[1]); + if (args.size() >= 3) { + *ret = CanFactorZeroFromCombiner(args[0], args[1], args[2]); + } else { + *ret = CanFactorZeroFromCombiner(args[0], args[1]); + } }); TVM_REGISTER_API("ir_pass.LiftNonzeronessCondition") @@ -1705,7 +1721,11 @@ TVM_REGISTER_API("ir_pass.ExtractNonTopReductions") TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = OptimizeAndLiftNonzeronessConditions(args[0]); + if (args.size() >= 2) { + *ret = OptimizeAndLiftNonzeronessConditions(args[0], args[1]); + } else { + *ret = OptimizeAndLiftNonzeronessConditions(args[0]); + } }); } // namespace ir diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h index bcfba038cc02..8246390c7064 100644 --- a/src/pass/zero_elimination.h +++ b/src/pass/zero_elimination.h @@ -22,7 +22,8 @@ Expr CloneReduction(const Expr& expr); /*! * \brief Check if the given combiner represents summation. */ -EXPORT bool IsSumCombiner(const CommReducer& combiner); +EXPORT bool IsSumCombiner(const CommReducer& combiner, + const Map& vranges = Map()); /*! * \brief Check if zero may be factored out of a reduction with this combiner when it is in @@ -32,7 +33,8 @@ EXPORT bool IsSumCombiner(const CommReducer& combiner); * check that `(a, 0) combine (b, 0) = (c, 0)` for any a, b and some c. * Note that all combiners generated by autodiff have this property. */ -EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index); +EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map& vranges = Map()); /*! * \brief Transform the expression into `c ? e : 0`, that is lift the condition of being @@ -231,8 +233,14 @@ EXPORT Expr ExtractNonTopReductions(const Expr& expr, * \brief Perform lifting of conditions of being possible to be non-zero together with * applying some transformations like simplifying the reduction domain. Works only with * this particular tensor's body, i.e. doesn't perform inlining. + * + * \param tensor The original tensor; + * \param vranges Optional map from free variables to their value ranges. + * \return An optimized tensor. */ -EXPORT Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor); +EXPORT Tensor OptimizeAndLiftNonzeronessConditions( + const Tensor& tensor, + const Map& vranges = Map()); } // namespace ir } // namespace tvm diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index 9f8c110b7197..258a6dacaf9d 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -91,6 +91,13 @@ def _compute_body(*us): lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) xor_combiner = comm_reducer(lambda x, y: x ^ y, lambda t0: tvm.const(0, t0)) +m_param = tvm.var('m_param') +sum_or_prod_combiner = comm_reducer(lambda x, y: tvm.expr.Select(m_param < 0, x + y, x*y), + lambda t0: tvm.expr.Select(m_param < 0, + tvm.const(0, t0), tvm.const(1, t0))) +shifted_sum_combiner = comm_reducer(lambda x, y: x + y - m_param, + lambda t0: m_param) + def test_is_sum_combiner(): k = tvm.reduce_axis((0, 10), name="k") i = tvm.const(0, "int32") @@ -102,6 +109,11 @@ def test_is_sum_combiner(): assert not IsSumCombiner(sum_derivative_combiner((f, f), k)[0].combiner) assert not IsSumCombiner(prod_combiner(f, k).combiner) assert not IsSumCombiner(prod_derivative_combiner((f, f), k)[1].combiner) + assert not IsSumCombiner(sum_or_prod_combiner(f, k).combiner) + assert not IsSumCombiner(sum_or_prod_combiner(f, k).combiner, {m_param: tvm.Range(-5, 1)}) + assert IsSumCombiner(sum_or_prod_combiner(f, k).combiner, {m_param: tvm.Range(-5, -1)}) + assert not IsSumCombiner(shifted_sum_combiner(i, k).combiner) + assert IsSumCombiner(shifted_sum_combiner(i, k).combiner, {m_param: tvm.Range(0, 1)}) def test_can_factor_zero_from_combiner(): k = tvm.reduce_axis((0, 10), name="k") @@ -115,6 +127,13 @@ def test_can_factor_zero_from_combiner(): assert CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 1) assert CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 0) assert not CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 1) + assert not CanFactorZeroFromCombiner(sum_or_prod_combiner(f, k).combiner, 0, + {m_param: tvm.Range(-5, 1)}) + assert CanFactorZeroFromCombiner(sum_or_prod_combiner(f, k).combiner, 0, + {m_param: tvm.Range(-5, -1)}) + assert not CanFactorZeroFromCombiner(shifted_sum_combiner(i, k).combiner, 0) + assert CanFactorZeroFromCombiner(shifted_sum_combiner(i, k).combiner, 0, + {m_param: tvm.Range(0, 1)}) def test_lift_nonzeroness_condition(): k = tvm.reduce_axis((0, 5), name="k") @@ -462,6 +481,19 @@ def test_optimize_and_lift_nonzeroness(): check_eq(B, R, [A]) assert estimate_performance(B) <= estimate_performance(R) + # Specifying ranges of parameters + B = compute((10, 10), lambda i, j: sum_or_prod_combiner((i == j)*A[i, k] + A[k, j]*(i == j), k)) + B = OptimizeAndLiftNonzeronessConditions(B, {m_param: tvm.Range(-5, -3)}) + R = lambda i, j: tvm.expr.Select(i == j, + tvm.sum(A[j, k] + A[k, j], k), + zero) + check_tensor_symeq(B, R) + + B = compute((10, 10), lambda i, j: tvm.sum(((i - k) <= m_param) * A[i, k], k)) + B = OptimizeAndLiftNonzeronessConditions(B, {m_param: tvm.Range(11, 20)}) + R = lambda i, j: tvm.sum(A[i, k], k) + check_tensor_symeq(B, R) + if __name__ == "__main__": test_is_sum_combiner() test_can_factor_zero_from_combiner() From 7f5ef4d57e9f217597db312c64e1930fc85d2e7e Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Thu, 28 Feb 2019 18:14:50 +0300 Subject: [PATCH 05/10] Fix the bug in tensor inlining --- src/pass/zero_elimination.cc | 20 ++++++++----------- .../unittest/test_pass_zero_elimination.py | 13 ++++++++++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index f44d671e67bc..263f33e46d99 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -278,6 +278,7 @@ bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, return is_const_value(in, 0); } +// If expr is a Call node, perform inlining, otherwise do nothing Expr InlineThisCall(const Expr& expr) { if (const Call* op = expr.as()) { if (op->call_type == Call::CallType::Halide) { @@ -304,6 +305,7 @@ Tensor InlineTailCall(const Tensor& tensor) { return op::TransformBody(tensor, InlineThisCall); } +// Implements InlineTensors by trying to inline every Call of the given Expr class InlineTensorsMutator : public IRMutator { public: explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) @@ -317,26 +319,20 @@ class InlineTensorsMutator : public IRMutator { if (op->call_type == Call::CallType::Halide) { const ComputeOpNode* op_comp = op->func.as(); if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + // Inline only compute nodes that are not reductions (unless inline reductions is allowed) if (op_comp && (inline_reductions_ || !op_comp->body[0].as())) { - Array tensor_axes; - for (const auto& var : op_comp->axis) { - tensor_axes.push_back(var->var); - } - - Stmt inlined = Inline(Evaluate::make(e), op->func, tensor_axes, - op_comp->body[op->value_index]); - if (const ir::Evaluate* ev = inlined.as()) { - // If it is a reduction, clone it - return Mutate(ev->value); - } + // Inline this call and then try to perform further inlining + return Mutate(InlineThisCall(e)); } } } - return e; + // If we cannot inline this call, we should try to doinlining in its arguments + return IRMutator::Mutate_(op, e); } private: + // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index) std::set> inlineable_; bool inline_reductions_; }; diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index 258a6dacaf9d..527c5399bb86 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -194,6 +194,7 @@ def test_inline_tensors(): k = tvm.reduce_axis((0, 5), name="k") D = tvm.compute((10, 10), lambda i, j: tvm.sum(A[i, k], k)) E = tvm.compute((10, 10), lambda i, j: A[2, j] + C[i, 2] + D[i, j]) + F = tvm.compute((10, 10), lambda i, j: tvm.exp(A[i, j]) + B[i, A[i, j]]) R = InlineTensors(E) resbody = lambda i, j: 2 + j + i + 2 + i*2 + D[i, j] @@ -211,6 +212,18 @@ def test_inline_tensors(): resbody = lambda i, j: A[2, j] + (A[i, 2] + i*2) + D[i, j] check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + R = InlineTensors(F) + resbody = lambda i, j: tvm.exp(i + j) + i * (i + j) + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(F, [A]) + resbody = lambda i, j: tvm.exp(i + j) + B[i, (i + j)] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(F, [B]) + resbody = lambda i, j: tvm.exp(A[i, j]) + i * A[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + def test_solve_system_of_inequalities(): seed = random.randrange(sys.maxsize) print("\nseed: {}\n".format(seed)) From 8c8dfb05ac9359eaf6428bed4c34e9d99388eb05 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Mon, 4 Mar 2019 14:52:31 +0300 Subject: [PATCH 06/10] Fix treatment of if_then_else in LiftNonzeronessCondition --- src/pass/zero_elimination.cc | 308 ++++++++++-------- .../unittest/test_pass_zero_elimination.py | 9 + 2 files changed, 175 insertions(+), 142 deletions(-) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 263f33e46d99..e151b0df7583 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -360,17 +360,20 @@ struct NonzeronessConditionResult { } }; +// The implementation of NonzeronessCondition class NonzeronessConditionFunctor : public ExprFunctor { public: NonzeronessConditionResult NonzeronessCondition(const Expr& e) { if (e.type().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves return {e, const_true()}; } else { return VisitExpr(e, e); } } + // Most of the cases are implemented using helpers below result_type VisitExpr_(const Variable*, const Expr& e) final { return Default_(e); } result_type VisitExpr_(const IntImm* op, const Expr& e) final { return Const_(op, e); } result_type VisitExpr_(const UIntImm* op, const Expr& e) final { return Const_(op, e); } @@ -385,32 +388,62 @@ class NonzeronessConditionFunctor result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } result_type VisitExpr_(const Cast* op, const Expr& e) final { - if (op->value.type().is_bool()) { - return {op->value, make_const(e.type(), 1)}; - } else { - auto nz_a = NonzeronessCondition(op->value); + auto nz_a = NonzeronessCondition(op->value); - if (nz_a.value.same_as(op->value)) { - return {nz_a.cond, e}; - } else { - return {nz_a.cond, Cast::make(op->type, nz_a.value)}; - } + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, Cast::make(op->type, nz_a.value)}; } } result_type VisitExpr_(const Select* op, const Expr& e) final { - return SelectLike_(e, op->condition, op->true_value, op->false_value, Select::make); + Expr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + Expr new_cond = SuperSimplify(nz_a.cond && cond); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + Expr new_cond = SuperSimplify(nz_b.cond && !cond); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, Select::make(cond, nz_a.value, nz_b.value)}; + } } result_type VisitExpr_(const Call* op, const Expr& e) final { if (op->name == intrinsic::tvm_if_then_else) { - return SelectLike_(e, op->args[0], op->args[1], op->args[2], if_then_else); + Expr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } } else { return Default_(e); } } NonzeronessConditionResult Default_(const Expr& e) { + // This is always correct, so it's the default return {const_true(), e}; } @@ -423,43 +456,27 @@ class NonzeronessConditionFunctor } } - template - NonzeronessConditionResult SelectLike_(const Expr& e, const Expr& cond, const Expr& true_val, - const Expr& false_val, make_select_type make_select) { - auto nz_a = NonzeronessCondition(true_val); - auto nz_b = NonzeronessCondition(false_val); - - if (is_const_value(nz_b.value, 0)) { - Expr new_cond = SuperSimplify(nz_a.cond && cond); - return {new_cond, nz_a.value}; - } - - if (is_const_value(nz_a.value, 0)) { - Expr new_cond = SuperSimplify(nz_b.cond && !cond); - return {new_cond, nz_b.value}; - } - - Expr new_cond = SuperSimplify((cond && nz_a.cond) || (!cond && nz_b.cond)); - if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { - return {new_cond, e}; - } else { - return {new_cond, make_select(cond, nz_a.value, nz_b.value)}; - } - } - template NonzeronessConditionResult BinOpAddLike_(const TNode* op, const Expr& e) { auto nz_a = NonzeronessCondition(op->a); auto nz_b = NonzeronessCondition(op->b); + // For addition and similar ops the result may be nonzero if either of the arguments is + // nonzero, so we combine the conditions with Or. + if (Equal(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { return {nz_a.cond, e}; } else { return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; } } else { + // Otherwise use Or Expr new_cond = SuperSimplify(nz_a.cond || nz_b.cond); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); Expr new_expr = TNode::make(new_a, new_b); @@ -472,6 +489,9 @@ class NonzeronessConditionFunctor auto nz_a = NonzeronessCondition(op->a); auto nz_b = NonzeronessCondition(op->b); + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + Expr new_cond = SuperSimplify(nz_a.cond && nz_b.cond); if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { @@ -485,6 +505,8 @@ class NonzeronessConditionFunctor NonzeronessConditionResult BinOpDivLike_(const TNode* op, const Expr& e) { auto nz_a = NonzeronessCondition(op->a); + // For Div we simply use the condition of the numerator. + if (nz_a.value.same_as(op->a)) { return {nz_a.cond, e}; } else { @@ -493,6 +515,8 @@ class NonzeronessConditionFunctor } }; +// Transform expr into a pair (condition, new_expr) such that the old expr is equivalent to +// `select(condition, new_expr, 0)`. The pair is represented as a struct for clarity. NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { return NonzeronessConditionFunctor().NonzeronessCondition(expr); } @@ -627,6 +651,110 @@ FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { } +class RemoveRedundantInequalitiesMutator : public IRMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const Expr& cond : known) { + known_.push_back(SuperSimplify(cond)); + } + } + + virtual Expr Mutate_(const Select* op, const Expr& e) { + bool has_side_effect = HasSideEffect(e); + Expr new_cond = SuperSimplify(Mutate(op->condition)); + if (is_one(new_cond) && !has_side_effect) { + return Mutate(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return Mutate(op->false_value); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); + } + } + + virtual Expr Mutate_(const Call* op, const Expr& e) { + if (op->name == intrinsic::tvm_if_then_else) { + Expr new_cond = SuperSimplify(Mutate(op->args[0])); + if (is_one(new_cond)) { + return Mutate(op->args[1]); + } else if (is_zero(new_cond)) { + return Mutate(op->args[2]); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); + } + } else { + return IRMutator::Mutate_(op, e); + } + } + + virtual Expr Mutate_(const Reduce* op, const Expr& e) { + Array known_with_axes = known_; + for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + Expr new_cond = mutator_with_axes.Mutate(op->condition); + + Array new_known = known_with_axes; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } + + virtual Expr Mutate_(const And* op, const Expr& e) { + return Mutate(op->a) && Mutate(op->b); + } + + private: + Expr MutateAtomic_(const Expr& e) { + Expr simplified = SuperSimplify(e); + for (const Expr& other : known_) { + if (Equal(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; +}; + +// Propagate information from conditions and remove redundant inequalities +// TODO(sgrechanik-h): This should be merged into standard simplifiers +Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known).Mutate(expr); +} + + struct EliminateDivModResult { Expr expr; Map substitution; @@ -1246,7 +1374,9 @@ Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, // TODO(sgrechanik-h): We don't use divmod elimination here because of some performance problems auto res = SimplifyDomain(cond, outer_axis, vranges, false); - Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), vranges); + Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), Merge(vranges, res.ranges)); + // This is mostly done to simplify if_then_else which is not known by the Halide simplifier + new_expr = RemoveRedundantInequalities(new_expr, res.conditions); // Keep only those variables of the new axis which are used in the new_expr { @@ -1260,9 +1390,6 @@ Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, res.axis = std::move(used_res_axis); } - // Use the new axis to simplify the new expr, removing redundant inequalities - new_expr = SuperSimplify(new_expr, res.ranges); - // If the expression does not use vars then it is probably better to keep it inlined if (res.axis.empty()) { return new_expr; @@ -1298,109 +1425,6 @@ Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, } -class RemoveRedundantInequalitiesMutator : public IRMutator { - public: - explicit RemoveRedundantInequalitiesMutator(Array known) { - for (const Expr& cond : known) { - known_.push_back(SuperSimplify(cond)); - } - } - - virtual Expr Mutate_(const Select* op, const Expr& e) { - bool has_side_effect = HasSideEffect(e); - Expr new_cond = SuperSimplify(Mutate(op->condition)); - if (is_one(new_cond) && !has_side_effect) { - return Mutate(op->true_value); - } else if (is_zero(new_cond) && !has_side_effect) { - return Mutate(op->false_value); - } else { - Array new_known = known_; - for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { - new_known.push_back(atomic); - } - RemoveRedundantInequalitiesMutator new_mutator(new_known); - // Note that we mutate only the true value with the new mutator - // TODO(sgrechanik-h): Update known conditions for the false value as well - return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); - } - } - - virtual Expr Mutate_(const Call* op, const Expr& e) { - if (op->name == intrinsic::tvm_if_then_else) { - Expr new_cond = SuperSimplify(Mutate(op->args[0])); - if (is_one(new_cond)) { - return Mutate(op->args[1]); - } else if (is_zero(new_cond)) { - return Mutate(op->args[2]); - } else { - Array new_known = known_; - for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { - new_known.push_back(atomic); - } - RemoveRedundantInequalitiesMutator new_mutator(new_known); - // Note that we mutate only the true value with the new mutator - // TODO(sgrechanik-h): Update known conditions for the false value as well - return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); - } - } else { - return IRMutator::Mutate_(op, e); - } - } - - virtual Expr Mutate_(const Reduce* op, const Expr& e) { - Array known_with_axes = known_; - for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { - known_with_axes.push_back(axis_cond); - } - RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); - - Expr new_cond = mutator_with_axes.Mutate(op->condition); - - Array new_known = known_with_axes; - for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { - new_known.push_back(atomic); - } - RemoveRedundantInequalitiesMutator new_mutator(new_known); - - Array new_source; - for (const Expr& src : op->source) { - new_source.push_back(new_mutator.Mutate(src)); - } - - return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); - } - - virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } - virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } - virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } - virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } - virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } - virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } - - virtual Expr Mutate_(const And* op, const Expr& e) { - return Mutate(op->a) && Mutate(op->b); - } - - private: - Expr MutateAtomic_(const Expr& e) { - Expr simplified = SuperSimplify(e); - for (const Expr& other : known_) { - if (Equal(simplified, other)) { - return const_true(); - } - } - return simplified; - } - - Array known_; -}; - -// Propagate information from conditions and remove redundant inequalities -// TODO(sgrechanik-h): This should be merged into standard simplifiers -Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { - return RemoveRedundantInequalitiesMutator(known).Mutate(expr); -} - // Extract from cond an implication of cond not containing vars std::pair ImplicationNotContainingVars( const Expr& cond, const std::unordered_set& vars) { diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index 527c5399bb86..c315d10eb93b 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -180,6 +180,15 @@ def _check_symeq(expr1, expr2): _check_symeq(tvm.min(tvm.expr.Cast('int32', k < n)*l, tvm.expr.Select(k >= n, 0, 1)), tvm.expr.Select(k < n, tvm.min(l, 1), 0)) + expr1 = tvm.if_then_else(k < n, + tvm.expr.Select(tvm.expr.EQ(k, l), A[k], 0.0), + tvm.expr.Select(l < n, A[l], 0.0)) + expr2 = tvm.expr.Select(tvm.any(tvm.all(k < n, tvm.expr.EQ(k, l)), + tvm.all(k >= n, l < n)), + tvm.if_then_else(k < n, A[k], A[l]), + 0.0) + check_symeq(LiftNonzeronessCondition(expr1), expr2) + def test_inline_tail_call(): A = tvm.compute((10, 10), lambda i, j: i + j*j) B = tvm.compute((5, 6), lambda k, l: A[k + l, k + 1]) From f697063088e2860f17ae1edafdeb64db21519589 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Mon, 4 Mar 2019 18:11:18 +0300 Subject: [PATCH 07/10] More comments; Fixed the behaviour of divmod elimination for equal exprs --- python/tvm/testing.py | 36 +++++--- src/pass/zero_elimination.cc | 85 +++++++++++++++---- .../unittest/test_pass_zero_elimination.py | 3 +- 3 files changed, 95 insertions(+), 29 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index afdca6a19720..2db7c164b1e8 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -197,7 +197,7 @@ def __le__(self, other): self.memory <= other.memory -def estimate_performance(s, param_values=None, processed_ops=None): +def estimate_performance(s, param_values=None, _processed_ops=None): """Statically estimate performance of statements, expressions and tensors. Note that the estimate is very rough, it mustn't be used to predict future performance, its only purpose is to detect possible performance regressions. @@ -209,11 +209,25 @@ def estimate_performance(s, param_values=None, processed_ops=None): of any of the above. param_values : Dict[tvm.expr.Var, int], optional - Values for parameters (free variables). + Values for parameters (free variables), see the example. + + _processed_ops, optional + A dict mapping already processed operations to the corresponding estimations. + This parameter is used internally. Returns ------- estimate : PerformanceEstimate + + Example + ------- + .. code-block:: python + + m = tvm.var('m') + X = tvm.placeholder((10, m), name='X') + W = tvm.placeholder((m + 5, m), name='W') + A = topi.nn.dense(X, W) + tvm.testing.estimate_performance(A, param_values={m: 5}) """ from tvm import stmt from tvm import expr @@ -221,17 +235,17 @@ def estimate_performance(s, param_values=None, processed_ops=None): if param_values is None: param_values = {} - if processed_ops is None: - processed_ops = {} - res = estimate_performance(s, param_values=param_values, processed_ops=processed_ops) - for op_est in processed_ops.values(): + if _processed_ops is None: + _processed_ops = {} + res = estimate_performance(s, param_values=param_values, _processed_ops=_processed_ops) + for op_est in _processed_ops.values(): res += op_est return res - def est(expression, param_values=param_values, processed_ops=processed_ops): + def est(expression, param_values=param_values, _processed_ops=_processed_ops): return estimate_performance(expression, param_values=param_values, - processed_ops=processed_ops) + _processed_ops=_processed_ops) def _eval(expression, param_values=param_values): return tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expression, param_values)).value @@ -250,7 +264,7 @@ def _prod(elems): for item in s: res += est(item) return res - elif s in processed_ops: + elif s in _processed_ops: return PerformanceEstimate() elif isinstance(s, stmt.Allocate): mem = _prod([_eval(e) for e in s.extents]) @@ -286,7 +300,7 @@ def _prod(elems): for a in s.args: res += est(a) if s.call_type == expr.Call.Halide: - # The estimate is added to processed_ops, we don't need the result here + # The estimate is added to _processed_ops, we don't need the result here est(s.func) elif s.name == "tvm_if_then_else": pass @@ -324,7 +338,7 @@ def _prod(elems): res += est(b) res.iterations = max(1, res.iterations) res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body)) - processed_ops[s] = res + _processed_ops[s] = res return PerformanceEstimate() raise ValueError("Don't know how to estimate performance of {} of type {}" diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index e151b0df7583..4e90a7511440 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -317,17 +317,19 @@ class InlineTensorsMutator : public IRMutator { Expr Mutate_(const Call* op, const Expr& e) { if (op->call_type == Call::CallType::Halide) { - const ComputeOpNode* op_comp = op->func.as(); - if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { - // Inline only compute nodes that are not reductions (unless inline reductions is allowed) - if (op_comp && (inline_reductions_ || !op_comp->body[0].as())) { - // Inline this call and then try to perform further inlining - return Mutate(InlineThisCall(e)); + if (const ComputeOpNode* op_comp = op->func.as()) { + // Inline only if the array of inlineable tensors is empty or contains this tensor + if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + // Inline only compute nodes that are not reductions (unless inline reductions is allowed) + if (inline_reductions_ || !op_comp->body[0].as()) { + // Inline this call and then try to perform further inlining + return Mutate(InlineThisCall(e)); + } } } } - // If we cannot inline this call, we should try to doinlining in its arguments + // If we cannot inline this call, we should try to do inlining in its arguments return IRMutator::Mutate_(op, e); } @@ -565,13 +567,16 @@ struct FactorOutAtomicFormulasResult { } }; +// The implementation of FactorOutAtomicFormulas class FactorOutAtomicFormulasFunctor : public ExprFunctor { public: result_type Atomic_(const Expr& e) { + // For atomic expressions the result is the expr itself with True as the residual return {{e}, make_const(e.type(), 1)}; } + // This is basically the list of expression kinds that are considered atomic result_type VisitExpr_(const Variable*, const Expr& e) final { return Atomic_(e); } result_type VisitExpr_(const Call*, const Expr& e) final { return Atomic_(e); } result_type VisitExpr_(const IntImm*, const Expr& e) final { return Atomic_(e); } @@ -587,6 +592,7 @@ class FactorOutAtomicFormulasFunctor auto res_a = VisitExpr(op->a, op->a); auto res_b = VisitExpr(op->b, op->b); + // For the And case we return the union of the sets of atomic formulas std::vector res; res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), @@ -594,6 +600,7 @@ class FactorOutAtomicFormulasFunctor std::back_inserter(res), ExprLess()); + // And the residuals are combined with && return {res, res_a.rest && res_b.rest}; } @@ -601,6 +608,7 @@ class FactorOutAtomicFormulasFunctor auto res_a = VisitExpr(op->a, op->a); auto res_b = VisitExpr(op->b, op->b); + // For multiplication we do the same thing as for And std::vector res; res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), @@ -615,6 +623,7 @@ class FactorOutAtomicFormulasFunctor auto res_a = VisitExpr(op->a, op->a); auto res_b = VisitExpr(op->b, op->b); + // For the Or case we intersect the sets of atomic formulas std::vector res; res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); std::set_intersection(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), @@ -622,6 +631,9 @@ class FactorOutAtomicFormulasFunctor std::back_inserter(res), ExprLess()); + // Computing the residual is more complex: we have to compute the sets of atomic formulas + // which are left behind, and then combine them with the residuals into the new residual. + std::vector new_cond_a; new_cond_a.reserve(res_a.atomic_formulas.size() - res.size()); std::set_difference(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), @@ -645,7 +657,9 @@ class FactorOutAtomicFormulasFunctor } }; -// Transform the given formula into an array of atomic formulas and a non-atomic residual. +// Transform the given formula into a conjunction of atomic formulas (represented as an array) +// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b, +// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level. FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { return FactorOutAtomicFormulasFunctor().VisitExpr(e, e); } @@ -776,16 +790,18 @@ class EliminateDivModMutator : public IRMutator { virtual Expr Mutate_(const Div* op, const Expr& e) { const IntImm* imm = op->b.as(); if (imm && imm->value > 0) { - auto it = expr_to_vars_.find({op->a.get(), imm->value}); + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find({op->a, imm->value}); if (it != expr_to_vars_.end()) { return it->second.first; } + // Otherwise recursively mutate the left hand side, and create new variables Expr mutated_a = Mutate(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { return var_pair_opt.value().first; } else { - return mutated_a / Mutate(op->b); + return mutated_a / op->b; } } @@ -795,16 +811,18 @@ class EliminateDivModMutator : public IRMutator { virtual Expr Mutate_(const Mod* op, const Expr& e) { const IntImm* imm = op->b.as(); if (imm && imm->value > 0) { - auto it = expr_to_vars_.find({op->a.get(), imm->value}); + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find({op->a, imm->value}); if (it != expr_to_vars_.end()) { return it->second.second; } + // Otherwise recursively mutate the left hand side, and create new variables Expr mutated_a = Mutate(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { return var_pair_opt.value().second; } else { - return mutated_a % Mutate(op->b); + return mutated_a % op->b; } } @@ -815,23 +833,35 @@ class EliminateDivModMutator : public IRMutator { dmlc::optional> AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { using tresult = dmlc::optional>; + // Try to find the variables using the mutated expressions + if (!e.same_as(mut)) { + auto it = expr_to_vars_.find({mut, val}); + if (it != expr_to_vars_.end()) { + return tresult(it->second); + } + } + Expr val_e = make_const(e.type(), val); idx_ += 1; + // Convert `ranges` to IntSets std::unordered_map var_intsets; for (const auto& p : ranges) { var_intsets[p.first.get()] = IntSet::range(p.second); } + // Infer ranges for the expressions we want to replace with variables Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + // We don't want to add unbounded variables if (!div_range.get() || !mod_range.get()) { LOG(WARNING) << "EliminateDivMod: won't eliminate div or mod of expr " << e << " because its bounds cannot be inferred"; return tresult(); } + // Create new variables for the expressions auto div = Var("div" + std::to_string(idx_), e.type()); auto mod = Var("mod" + std::to_string(idx_), e.type()); @@ -844,26 +874,49 @@ class EliminateDivModMutator : public IRMutator { ranges.Set(div, div_range); ranges.Set(mod, mod_range); + // This additional condition works as a definition for the new variables conditions.push_back(mut == div*val_e + mod); if (!CanProve(mod_range->extent <= val_e)) { + // Since we use the C/C++ definition of mod, there may be multiple values of `mod` + // satisfying the added condition if the expr `e` may change its sign, so we + // have to add another condition. LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod of expr " << e << " (probably it may change its sign)"; - // We cannot prove that mod is unique, so add additional condition conditions.push_back(Select::make(e >= 0, mod >= 0, mod <= 0)); } auto p = std::make_pair(div, mod); - expr_to_vars_[{e.get(), val}] = p; + expr_to_vars_[{e, val}] = p; + if (!e.same_as(mut)) { + expr_to_vars_[{mut, val}] = p; + } return tresult(p); } + // A custom comparison function for pairs of exprs and numbers. Compares exprs deeply. + struct Compare_ { + bool operator()(const std::pair& p1, const std::pair& p2) { + if (p1.second < p2.second) { + return true; + } else if (p1.second == p2.second) { + return Compare(p1.first, p2.first) < 0; + } else { + return false; + } + } + }; + + // A counter for naming new variables int idx_{0}; - std::map, std::pair> + // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod) + // such that `div = e / n` and `mod = e % n` + std::map, std::pair, Compare_> expr_to_vars_; }; -// replace every subexpr of the form e/const and e % const with a new variable +// Replace every subexpr of the form e/const and e % const with a new variable. +// Syntactically equal expressions will be mapped to the same variable. EliminateDivModResult EliminateDivMod(const Expr& expr, Map ranges) { EliminateDivModResult res; EliminateDivModMutator mutator(ranges); diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index c315d10eb93b..cba13746315e 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -380,8 +380,7 @@ def _check(cond, axis, volume, vranges={}): k = tvm.reduce_axis((0, 10), name="k") l = tvm.reduce_axis((0, 10), name="l") - # TODO: This is not fully optimized because we don't have a solver - _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 144) + _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 48) def test_extract_as_tensor_maybe(): def _check(shape, fcompute, volume=None, vranges={}): From ea1c18476bb00dac9a4c19ed94173614854bd681 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 13 Mar 2019 17:30:14 +0300 Subject: [PATCH 08/10] Reverse the resulting axis array in SimplifyDomain --- src/pass/zero_elimination.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 4e90a7511440..559b6eb1bcd7 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -1394,6 +1394,9 @@ DomainSimplificationResult SimplifyDomain(const Expr& cond, res.conditions.push_back(SuperSimplify(Substitute(old_cond, res.old_to_new), vranges)); } + // Reverse the axis so that it matches the order of the original variables + res.axis = Array(res.axis.rbegin(), res.axis.rend()); + return res; } From 58ba0fb8eadef91d0054b15a139b99537d1f37b4 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 13 Mar 2019 19:17:11 +0300 Subject: [PATCH 09/10] Don't extract tensor calls in ExtractAsTensorMaybe --- src/pass/zero_elimination.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 559b6eb1bcd7..fd987a9f6b52 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -1448,9 +1448,18 @@ Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, // If the expression does not use vars then it is probably better to keep it inlined if (res.axis.empty()) { + // We can return the new_expr here instead of the old e because it doesn't use variables + // otherwise we would need to replace the new vars or create a let-expression return new_expr; } + // If it's already a call to a tensor then extracting it will probably be useless + if (const Call* call = new_expr.as()) { + if (call->call_type == Call::CallType::Halide) { + return e; + } + } + // Compute volumes before and after Expr old_volume = make_const(Int(64), 1); for (const Var& var : outer_axis) { From cd1375e21ee87dfdabd5adcb50d187b03011f47c Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 13 Mar 2019 19:28:43 +0300 Subject: [PATCH 10/10] Fix lint errors --- src/pass/zero_elimination.cc | 2 ++ src/pass/zero_elimination.h | 1 + 2 files changed, 3 insertions(+) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index fd987a9f6b52..775476d61fae 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include "arithmetic/ModulusRemainder.h" #include "../op/op_util.h" diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h index 8246390c7064..600b3cb4162f 100644 --- a/src/pass/zero_elimination.h +++ b/src/pass/zero_elimination.h @@ -10,6 +10,7 @@ #include #include +#include namespace tvm { namespace ir {