Skip to content

Commit

Permalink
[REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare
Browse files Browse the repository at this point in the history
This PR introduces ExprDeepEqual which reuses the StructuralEqual infra.
We migrated the usecases of ir_pass::Equal to ExprDeepEqual and StructuralEqual.
  • Loading branch information
tqchen committed Apr 1, 2020
1 parent e722301 commit 266d914
Show file tree
Hide file tree
Showing 45 changed files with 419 additions and 641 deletions.
9 changes: 8 additions & 1 deletion docs/api/python/tir.rst
Expand Up @@ -24,10 +24,17 @@ tvm.tir
:autosummary:



tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:


tvm.tir.analysis
----------------
.. automodule:: tvm.tir.analysis
:members:
:imported-members:
:autosummary:
54 changes: 54 additions & 0 deletions include/tvm/tir/analysis.h
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/analysis.h
* \brief Analysis utilitie and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace tir {

/*!
* \brief Compare two expressions recursively and check if they are equal
* to each other without var remapping.
*
* This function do not remap variable bindings, it will not
* return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
*
* Use StructuralEqual for such cases.
*
* Due to the restriction of not remapping variables, this function can run
* faster than StructuralEqual and can be used as an utility function during arithmetic
* simplifications.
*
* \sa StructuralEqual
*/
struct ExprDeepEqual {
public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
11 changes: 11 additions & 0 deletions include/tvm/tir/expr.h
Expand Up @@ -920,6 +920,17 @@ class FunctionBaseNode : public Object {
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;

// fall back to pointer equality now before refactor.
bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
};

/*! \brief reference to a function */
Expand Down
29 changes: 0 additions & 29 deletions include/tvm/tir/ir_pass.h
Expand Up @@ -76,35 +76,6 @@ Stmt CanonicalSimplify(Stmt stmt,
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs);

/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Stmt& lhs, const Stmt& rhs);

/*!
* \brief Deep compare lhs and rhs.
*
* If you only want equality comparison, use Equal
* which will also tie definitions. The compare mode
* will give order of expression in total order.
*
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
int Compare(const PrimExpr& lhs, const PrimExpr& rhs);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/hybrid/calls.py
Expand Up @@ -22,7 +22,6 @@
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import ir_pass
from tvm.tir import call_pure_intrin
from tvm.tir.stmt import For

Expand All @@ -47,7 +46,7 @@ def _range(annotation, args):
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, const(0, dtype='int32')):
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/hybrid/parser.py
Expand Up @@ -56,7 +56,7 @@ def concat_list_to_block(lst):
def visit_list_to_block(visit, lst):
"""Visit and concatenate a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
lst = [stmt for stmt in lst if not tvm.ir.structural_equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
return concat_list_to_block(lst)
Expand Down Expand Up @@ -178,7 +178,7 @@ def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
self.binds[val.var.name] = val
return
val_ = self.binds[val.var.name]
_internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
_internal_assert(tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent),
"Thread extents should be uniform!")
self.symbols[key] = ty, val_

Expand Down Expand Up @@ -525,7 +525,7 @@ def visit_For(self, node):
if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if not _ir_pass.Equal(low, tvm.runtime.const(0, 'int32')):
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/ir/base.py
Expand Up @@ -198,6 +198,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
structural_hash
assert_strucural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars))

Expand Down Expand Up @@ -225,6 +227,8 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
--------
structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Expand Up @@ -46,3 +46,4 @@
from . import ir_builder
from . import ir_pass
from . import transform
from . import analysis
20 changes: 20 additions & 0 deletions python/tvm/tir/analysis/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR analysis utils."""
# pylint: disable=wildcard-import, invalid-name

from .analysis import *
21 changes: 21 additions & 0 deletions python/tvm/tir/analysis/_ffi_api.py
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.analysis"""
import tvm._ffi


tvm._ffi._init_api("tir.analysis", __name__)
57 changes: 57 additions & 0 deletions python/tvm/tir/analysis/analysis.py
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name

from . import _ffi_api


def expr_deep_equal(lhs, rhs):
"""Deeply compare two nested expressions.
Parameters
----------
lhs : PrimExpr
The left operand.
rhs : PrimExpr
The right operand.
Returns
-------
result : bool
The comparison result
Note
----
This function do not remap variable bindings, it will not
return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
Use py:func:`tvm.ir.structural_equal` to handle structural variable remapping.
Due to the restriction of not remapping variables, this function can run
faster than StructuralEqual and can be used as an utility function during arithmetic
simplifications.
Always consider py:func:`tvm.ir.structural_equal` first, which handles
the structural remapping.
See Also
--------
tvm.ir.structural_equal
"""
return _ffi_api.expr_deep_equal(lhs, rhs)
4 changes: 3 additions & 1 deletion src/arith/canonical_simplify.cc
Expand Up @@ -23,6 +23,8 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/analysis.h>

#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
Expand Down Expand Up @@ -157,7 +159,7 @@ class SplitExpr : public PrimExpr {

inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true;
return tir::Equal(index, other->index);
return tir::ExprDeepEqual()(index, other->index);
}

inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
Expand Down
3 changes: 2 additions & 1 deletion src/arith/const_int_bound.cc
Expand Up @@ -138,10 +138,11 @@ class ConstIntBoundAnalyzer::Impl :

Entry VisitExpr(const PrimExpr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
tir::ExprDeepEqual equal;
// a linear search over additional info
// assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) {
if (tir::Equal(expr, info.expr)) {
if (equal(expr, info.expr)) {
res = Intersect(res, info.bound);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/arith/pattern_match.h
Expand Up @@ -66,6 +66,7 @@
#define TVM_ARITH_PATTERN_MATCH_H_

#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tuple>
#include "const_fold.h"

Expand Down Expand Up @@ -135,7 +136,7 @@ class PEqualChecker<PrimExpr> {
public:
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.same_as(rhs)) return true;
return tir::Equal(lhs, rhs);
return tir::ExprDeepEqual()(lhs, rhs);
}
};

Expand Down
9 changes: 5 additions & 4 deletions src/arith/rewrite_simplify.cc
Expand Up @@ -101,11 +101,11 @@ TryCompare(const PrimExpr& x, int64_t val) {
}

void RewriteSimplifier::Impl::
Update(const Var& var, const PrimExpr& info, bool override) {
if (!override) {
Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(Equal(it->second, info))
CHECK(ExprDeepEqual()(it->second, info))
<< "Trying to update var \'" << var << "\'"
<< " with a different value: "
<< "original=" << it->second
Expand Down Expand Up @@ -1716,10 +1716,11 @@ VisitExpr_(const CallNode* op) {
return op->args[0] & op->args[1];
}
}
ExprDeepEqual expr_equal;
if (op->is_intrinsic(CallNode::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/arith/stmt_simplify.cc
Expand Up @@ -23,7 +23,9 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>

#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
#include "ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -83,7 +85,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
op = stmt.as<StoreNode>();
if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
tir::ExprDeepEqual()(load->index, op->index)) {
return EvaluateNode::make(0);
}
}
Expand Down
1 change: 0 additions & 1 deletion src/node/structural_equal.cc
Expand Up @@ -225,7 +225,6 @@ class RemapVarSEqualHandler :
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};


TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
Expand Down

0 comments on commit 266d914

Please sign in to comment.