Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Non-recursive Graph Vistor and Rewriter #4886

Merged
merged 34 commits into from Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bd14b04
First pass a defining a non-recursive Graph Vistor and Rewriter
Feb 11, 2020
bfabd87
Make CalcDep from Dead Code Elimination non-recursive
Mar 2, 2020
fd03421
Partially working, not passing all tests yet
Mar 12, 2020
31d0664
Refactor
Mar 14, 2020
7a68f0e
improve comments
Mar 18, 2020
6e4f7eb
respond to review comments on comments
Mar 20, 2020
0a80325
Fix a problem with default recursion for dataflow nodes
Mar 20, 2020
91747ef
implement ScopeMutator
mbrookhart Mar 27, 2020
2d5ae42
convert forward_rewrite to ScopeMutator, remove DataflowMutator
mbrookhart Mar 27, 2020
983f13c
rewrite ExprRewriter and convert fast_math to use it
mbrookhart Mar 28, 2020
5f3addf
switch BiasAddSimplifier to ExprRewriter
mbrookhart Mar 28, 2020
f188e7e
respond to review comments
Apr 1, 2020
cd2084e
fix a typo in the iterative looping
mbrookhart Apr 1, 2020
588daf3
add a regression test for GetExprRefCount issue
mbrookhart Apr 2, 2020
b1520c0
Normalize naming
mbrookhart Apr 2, 2020
62ca9f5
fix lint
mbrookhart Apr 2, 2020
25200bc
First pass a defining a non-recursive Graph Vistor and Rewriter
Feb 11, 2020
6e89865
Make CalcDep from Dead Code Elimination non-recursive
Mar 2, 2020
5cae577
Partially working, not passing all tests yet
Mar 12, 2020
55d6888
Refactor
Mar 14, 2020
6b95a91
improve comments
Mar 18, 2020
27d0760
respond to review comments on comments
Mar 20, 2020
bac5337
Fix a problem with default recursion for dataflow nodes
Mar 20, 2020
0af7763
implement ScopeMutator
mbrookhart Mar 27, 2020
4934eae
convert forward_rewrite to ScopeMutator, remove DataflowMutator
mbrookhart Mar 27, 2020
e08119c
rewrite ExprRewriter and convert fast_math to use it
mbrookhart Mar 28, 2020
b1a496b
switch BiasAddSimplifier to ExprRewriter
mbrookhart Mar 28, 2020
c1fc3e8
respond to review comments
Apr 1, 2020
ada8bd1
fix a typo in the iterative looping
mbrookhart Apr 1, 2020
509a087
add a regression test for GetExprRefCount issue
mbrookhart Apr 2, 2020
a7c8429
Normalize naming
mbrookhart Apr 2, 2020
ae3c62e
fix lint
mbrookhart Apr 2, 2020
dc89877
Merge branch 'mbrookhart/GraphVisitor' of github.com:mbrookhart/incub…
mbrookhart Apr 3, 2020
517e83d
respond to review comments
mbrookhart Apr 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/relay/analysis.h
Expand Up @@ -30,6 +30,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
#include <unordered_map>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);

/*!
* \brief Get reference counter of each internal ExprNode in body.
*
* \param body The body expression.
*
* \return The reference count mapping.
*/
TVM_DLL std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body);

} // namespace relay
} // namespace tvm

Expand Down
182 changes: 182 additions & 0 deletions include/tvm/relay/expr_functor.h
Expand Up @@ -232,6 +232,188 @@ class ExprMutator
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
*
* MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
/*! \brief The constructor of MixedModeVisitor
* \param visit_limit The number of times to allow visitation to a note. Usually 1, ocassionally
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
* higher (i.e., 2 for dead code elimiation), limited to 10 as a sanity check.
*/
MixedModeVisitor(int visit_limit = 1);
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief VisitExpr is finalized to preserve call expansion of dataflow regions
*/
void VisitExpr(const Expr& expr) final;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;

protected:
/*!
* \brief A function to apply when reaching a leaf of the graph non-recursively
*/
virtual void VisitLeaf(const Expr& expr);
/*!
* \brief A function to determine if an expression has already been visited or needs to be
* re-visited
*/
virtual bool CheckVisited(const Expr& expr);
/*!
* \brief The max number of times to visit a node
*/
size_t visit_limit_;
};

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*
* MixedModeMutator provides the same recursive API as ExprMutator, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*
* Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior.
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
Expr VisitExpr(const Expr& expr) final;
virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
/*!
* Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
* able to rewrite the op only with data about the original node `pre` and the same node with
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
* modified inputs `post` and should not recurse.
*
* \param pre The expression node before rewriting.
* \param post The expression with rewritten inputs.
*/
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
template <typename T>
Expr Rewrite(const T* op) {
Expr post = ExprMutator::VisitExpr_(op);
return Rewrite_(op, post);
}

virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
};

/*! \brief A non-iterating Expression Rewriter
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
});

#define EXPR_REWRITER_REWRITE_DEFAULT \
{ return post; }

class ExprRewriter {
private:
using TSelf = ExprRewriter;
using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;

public:
/*! \brief virtual destructor */
virtual ~ExprRewriter() {}
/*!
* \brief Same as call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
Expr operator()(const Expr& pre, const Expr& post) {
return Rewrite(pre, post);
}
/*!
* \brief The functor call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
virtual Expr Rewrite(const Expr& pre, const Expr& post) {
CHECK(pre.defined());
static FType vtable = InitVTable();
return vtable(pre, this, post);
}
// Functions that can be overriden by subclass, should not recurse
virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleGetItemNode* pre,
const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
RELAY_EXPR_REWRITER_DISPATCH(VarNode);
RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
RELAY_EXPR_REWRITER_DISPATCH(CallNode);
RELAY_EXPR_REWRITER_DISPATCH(LetNode);
RELAY_EXPR_REWRITER_DISPATCH(IfNode);
RELAY_EXPR_REWRITER_DISPATCH(OpNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
return vtable;
}
};

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
*/
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/util.cc
Expand Up @@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
class ExprRefCounter : private MixedModeVisitor {
public:
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
Expand Down