Skip to content

Commit

Permalink
rewrite ExprRewriter and convert fast_math to use it
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart authored and Matthew Brookhart committed Mar 31, 2020
1 parent 5e92c9c commit bc9e42e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 47 deletions.
91 changes: 57 additions & 34 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,6 @@ class ScopeMutator : public ::tvm::relay::ExprMutator {
virtual bool CheckVisited(const Expr& expr);
};

#define EXPR_REWRITER_VISIT_DEFAULT \
{ return Rewrite_(pre, post); }

#define EXPR_REWRITER_REWRITE_DEFAULT \
{ return post; }

/*! \brief A non-iterating Expression Rewriter
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
Expand All @@ -326,17 +320,43 @@ class ScopeMutator : public ::tvm::relay::ExprMutator {
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
class ExprRewriter : private ExprFunctor<Expr(const Expr&, const Expr&)> {
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
});

#define EXPR_REWRITER_REWRITE_DEFAULT \
{ return post; }

class ExprRewriter {
private:
using TSelf = ExprRewriter;
using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;

public:
/*! \brief Rewrite a node given the orginal form and the form with modified inputs
*
* Uses ExprFunctor for vtable access.
*
* Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
/*! \brief virtual destructor */
virtual ~ExprRewriter() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
Expr operator()(const Expr& pre, const Expr& post) {
return Rewrite(pre, post);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual Expr Rewrite(const Expr& pre, const Expr& post) { return this->VisitExpr(pre, post); }
virtual Expr Rewrite(const Expr& pre, const Expr& post) {
CHECK(pre.defined());
static FType vtable = InitVTable();
return vtable(pre, this, post);
}
// Functions that can be overriden by subclass, should not recurse
virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
Expand All @@ -355,24 +375,27 @@ class ExprRewriter : private ExprFunctor<Expr(const Expr&, const Expr&)> {
virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;

private:
Expr VisitExpr(const Expr& pre, const Expr& post) final {
return ExprFunctor::VisitExpr(pre, post);
};
Expr VisitExpr_(const VarNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const GlobalVarNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const ConstantNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const TupleNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const FunctionNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const CallNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const LetNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const IfNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const OpNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const TupleGetItemNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const RefCreateNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const RefReadNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const RefWriteNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const ConstructorNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
Expr VisitExpr_(const MatchNode* pre, const Expr& post) final EXPR_REWRITER_VISIT_DEFAULT;
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
RELAY_EXPR_REWRITER_DISPATCH(VarNode);
RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
RELAY_EXPR_REWRITER_DISPATCH(CallNode);
RELAY_EXPR_REWRITER_DISPATCH(LetNode);
RELAY_EXPR_REWRITER_DISPATCH(IfNode);
RELAY_EXPR_REWRITER_DISPATCH(OpNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
return vtable;
}
};

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
Expand All @@ -382,7 +405,7 @@ class ExprRewriter : private ExprFunctor<Expr(const Expr&, const Expr&)> {
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
*/
Expr PostOrderRewrite(const Expr& expr, const ExprRewriter& rewriter);
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
Expand Down
8 changes: 4 additions & 4 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,16 @@ Expr ScopeMutator::Mutate(const Expr& expr) {

class PostOrderRewriter : protected ScopeMutator {
public:
PostOrderRewriter(const ExprRewriter& rewriter) : rewriter_(rewriter) {}
PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
Expr VisitExpr(const Expr& expr) final {
auto post = ExprMutator::VisitExpr(expr);
return rewriter_.Rewrite(expr, post);
return rewriter_->Rewrite(expr, post);
}
protected:
ExprRewriter rewriter_;
ExprRewriter* rewriter_;
};

Expr PostOrderRewrite(const Expr& expr, const ExprRewriter& rewriter) {
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) {
return PostOrderRewriter(rewriter).VisitExpr(expr);
}

Expand Down
18 changes: 9 additions & 9 deletions src/relay/transforms/fast_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,19 @@
namespace tvm {
namespace relay {

class FastMathMutator : public ExprMutator {
class FastMathMutator : public ExprRewriter {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}

Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op == exp_op_) {
return FastExp(new_n.as<CallNode>()->args[0]);
} else if (n->op == tanh_op_) {
return FastTanh(new_n.as<CallNode>()->args[0]);
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (pre->op == exp_op_) {
return FastExp(post.as<CallNode>()->args[0]);
} else if (pre->op == tanh_op_) {
return FastTanh(post.as<CallNode>()->args[0]);
}
return new_n;
return post;
}

private:
Expand All @@ -56,7 +55,8 @@ class FastMathMutator : public ExprMutator {
};

Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e);
auto rewriter = FastMathMutator();
return PostOrderRewrite(e, &rewriter);
}

namespace transform {
Expand Down

0 comments on commit bc9e42e

Please sign in to comment.