Skip to content

Commit

Permalink
switch BiasAddSimplifier to ExprRewriter
Browse files Browse the repository at this point in the history
fix a clang warning

fix cpp lint

fix doc param error
  • Loading branch information
mbrookhart authored and Matthew Brookhart committed Mar 31, 2020
1 parent bc9e42e commit 9c6a04d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
20 changes: 10 additions & 10 deletions include/tvm/relay/expr_functor.h
Expand Up @@ -293,9 +293,9 @@ class ScopeMutator : public ::tvm::relay::ExprMutator {
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*/
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;};
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; };
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; };
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
Expand All @@ -320,9 +320,9 @@ class ScopeMutator : public ::tvm::relay::ExprMutator {
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
});

#define EXPR_REWRITER_REWRITE_DEFAULT \
Expand All @@ -338,17 +338,17 @@ class ExprRewriter {
virtual ~ExprRewriter() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
Expr operator()(const Expr& pre, const Expr& post) {
return Rewrite(pre, post);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
virtual Expr Rewrite(const Expr& pre, const Expr& post) {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr_functor.cc
Expand Up @@ -169,9 +169,9 @@ Expr ScopeMutator::Mutate(const Expr& expr) {

class PostOrderRewriter : protected ScopeMutator {
public:
PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
Expr VisitExpr(const Expr& expr) final {
auto post = ExprMutator::VisitExpr(expr);
auto post = ExprFunctor::VisitExpr(expr);
return rewriter_->Rewrite(expr, post);
}
protected:
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/canonicalize_ops.cc
Expand Up @@ -32,12 +32,12 @@
namespace tvm {
namespace relay {

class BiasAddSimplifier : public ExprMutator {
class BiasAddSimplifier : public ExprRewriter {
public:
BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}

Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
Expr Rewrite_(const CallNode* n, const Expr& post) override {
auto new_n = post;
if (n->op == bias_add_op_) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
Expand All @@ -63,7 +63,8 @@ class BiasAddSimplifier : public ExprMutator {
};

Expr CanonicalizeOps(const Expr& e) {
return BiasAddSimplifier().Mutate(e);
auto rewriter = BiasAddSimplifier();
return PostOrderRewrite(e, &rewriter);
}

namespace transform {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/transforms/dead_code.cc
Expand Up @@ -92,7 +92,7 @@ class Eliminator : private ExprMutator {
};

// calculate the dependency graph from expression
class CalcDep : private DataflowVisitor {
class CalcDep : protected DataflowVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
FindDef fd;
Expand All @@ -109,6 +109,8 @@ class CalcDep : private DataflowVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;

using DataflowVisitor::VisitExpr_;

void VisitLeaf(const Expr& e) final {
visit_counter_[e.get()]++;
// The dce code seprate variable into three parts:
Expand Down

0 comments on commit 9c6a04d

Please sign in to comment.