Skip to content

Commit

Permalink
[TIR][Arith] Additional Simplifications Inside Conditionals (#11524)
Browse files Browse the repository at this point in the history
* [TIR][Arith] Use equality constraints in analyzer

Previously, constraints with inequalities were recognized and used for
simplifications by `ConstIntBoundAnalyzer` and `ModularSetAnalyzer`,
but constraints with equalities were not.  This adds equality-based
constraints.  (e.g. Inside the then-case of `if i==5`, the value of
`i` is known to be 5.)

* [TIR][Arith] RewriteSimplifier, apply literal constraints

Previously, constraints were only checked within a `tir.likely`
annotation.  After this change, constraints are used for
simplification of all boolean expressions.  (e.g. Within a conditional
`if i==n`, the expression `(i==n) and (j==m)` can be simplified to
`j==m`.)

* [TIR][Arith] Do not apply literal constraints to BufferLoad

If a literal constraint relies on the contents of a buffer, the
constraint may not be assumed to hold.  This prevents the incorrect
rewriting of `A[i]==n` to true within a `if A[i]==n` conditional, as
the value of `A[i]` may have changed.

* [TIR][Arith] Use each independent constraints in RewriteSimplifier

Inside a constraint `if i==n and j==m`, both `i==n` and `j==m` may be
replaced with true, even in separate expressions.

This commit uses a new internal utility function
`tvm::arith::ExtractConstraints`, which breaks up a boolean expression
into a list of true statements.  This may be used to reduce
duplication elsewhere, such as `const_int_bound.cc` and
`iter_affine_map.cc`.

* [TIR][Arith] Check for negation of literal constraints

When inside a conditional of `i!=n`, in addition to the previous
replacement of `i!=n` with true, we can also replace `i==n` with
false.

* [TIR][Arith] Added unittests for new simplifications

* Fix lint error

* Fixed handling of negation of non-boolean types

* Removed extra asterisk
  • Loading branch information
Lunderberg committed Jun 2, 2022
1 parent 03eefe0 commit c78539c
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ class ConstIntBoundAnalyzer::Impl
if ((x < c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
}
if ((x == c).Match(cond) || (c == x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))};
}
if ((x && y).Match(cond)) {
auto ret1 = DetectBoundInfo(x.Eval());
auto ret2 = DetectBoundInfo(y.Eval());
Expand Down
55 changes: 55 additions & 0 deletions src/arith/constraint_extract.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/arith/constraint_extract.cc
*/

#include "constraint_extract.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

#include "pattern_match.h"

namespace tvm {
namespace arith {

void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector<PrimExpr>* collect) {
collect->push_back(expr);

PVar<PrimExpr> x, y;
if ((x && y).Match(expr)) {
CollectConstraints(x.Eval(), analyzer, collect);
CollectConstraints(y.Eval(), analyzer, collect);
} else if ((!(x || y)).Match(expr)) {
CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect);
CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect);
}
}

std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr) {
std::vector<PrimExpr> out;
Analyzer analyzer;
CollectConstraints(expr, &analyzer, &out);
return out;
}

} // namespace arith
} // namespace tvm
58 changes: 58 additions & 0 deletions src/arith/constraint_extract.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file contraint_extract.h
*
* \brief Centralized location for extraction of constraints from a boolean expression.
*/

#ifndef TVM_ARITH_CONSTRAINT_EXTRACT_H_
#define TVM_ARITH_CONSTRAINT_EXTRACT_H_

#include <tvm/tir/expr.h>

#include <vector>

namespace tvm {
namespace arith {

/* \brief Returns constraints that are true if the expression is true.
*
* Utility to break up a boolean expression into independent
* constraints.
*
* Example: `i==5 && j==3` => `[i==5 && j==3, i==5, j==3]`
* Example: `i==5 || j==3` => `[i==5 || j==3]`
* Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]`
*
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
*
* \param expr The expression to be analyzers
*
* \returns A vector of independent constraints
*/
std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr);

} // namespace arith
} // namespace tvm

#endif // TVM_ARITH_CONSTRAINT_EXTRACT_H_
4 changes: 4 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
if ((var == base).Match(constraint) || (base == var).Match(constraint)) {
Entry entry(1, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
return nullptr;
}

Expand Down
50 changes: 43 additions & 7 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "../target/datatype/registry.h"
#include "const_fold.h"
#include "constraint_extract.h"
#include "pattern_match.h"

namespace tvm {
Expand Down Expand Up @@ -228,7 +229,24 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
size_t old_literal_size = literal_constraints_.size();
// we will compare the already simplified result with the constraint,
// so simplify the constarint as well
literal_constraints_.push_back(operator()(constraint));
PrimExpr new_constraint = operator()(constraint);
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
// We could apply this during TryMatchLiteralConstraint, but
// that would require performing a rewrite of each expression
// being checked. This way, we only apply a rewrite for each
// constraint being applied.
PrimExpr negation;
if (subconstraint.dtype().is_bool()) {
negation = Not(subconstraint);
} else {
negation = subconstraint == make_zero(subconstraint.dtype());
}
negation = operator()(negation);
literal_constraints_.push_back(Not(negation));
}
}
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
ICHECK_EQ(literal_constraints_.size(), new_literal_size);
Expand Down Expand Up @@ -1291,11 +1309,27 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
return ret;
}

Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const {
PrimExpr negation = Not(expr);

ExprDeepEqual expr_equal;
for (const auto& constraint : literal_constraints_) {
if (expr_equal(constraint, expr)) {
return make_const(expr->dtype, true);
}
if (expr_equal(constraint, negation)) {
return make_const(expr->dtype, false);
}
}
return NullOpt;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
PrimExpr const_res = TryConstFold<EQ>(op->a, op->b);
if (const_res.defined()) return const_res;
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
Expand Down Expand Up @@ -1344,6 +1378,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
op = ret.as<LTNode>();
PrimExpr const_res = TryConstFold<LT>(op->a, op->b);
if (const_res.defined()) return const_res;
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
Expand Down Expand Up @@ -1475,6 +1510,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
op = ret.as<NotNode>();
PrimExpr const_res = TryConstFold<Not>(op->a);
if (const_res.defined()) return const_res;
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
PVar<int> lanes;
Expand All @@ -1499,6 +1536,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
op = ret.as<AndNode>();
PrimExpr const_res = TryConstFold<And>(op->a, op->b);
if (const_res.defined()) return const_res;
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
Expand Down Expand Up @@ -1538,6 +1576,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
op = ret.as<OrNode>();
PrimExpr const_res = TryConstFold<Or>(op->a, op->b);
if (const_res.defined()) return const_res;
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
Expand Down Expand Up @@ -1602,13 +1641,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
return op->args[0] << op->args[1];
}
}
ExprDeepEqual expr_equal;
if (op->op.same_as(tir::builtin::likely())) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (auto match = TryMatchLiteralConstraint(op->args[0])) {
return match.value();
}
}
return ret;
Expand Down
9 changes: 9 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
bool CanInlineLet(const LetNode* op);

/*! \brief Internal function to apply constraints
*
* Tests whether the expression is known to be true or false based
* on existing constraints. If the expression or its negation
* matches a constraint, return the boolean it should be replaced
* with. Otherwise, return false.
*/
Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;

private:
// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
Expand Down
Loading

0 comments on commit c78539c

Please sign in to comment.