Skip to content

Commit

Permalink
First pass a defining a non-recursive Graph Vistor and Rewriter
Browse files Browse the repository at this point in the history
autoformat

remove a currently empty test until testing is solidfied
  • Loading branch information
Matthew Brookhart committed Feb 28, 2020
1 parent a449d8b commit 9c952ee
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 13 deletions.
91 changes: 90 additions & 1 deletion include/tvm/relay/expr_functor.h
Expand Up @@ -28,6 +28,7 @@
#include <tvm/node/functor.h>
#include <tvm/ir/error.h>

#include <stack>
#include <string>
#include <utility>
#include <unordered_map>
Expand Down Expand Up @@ -197,7 +198,7 @@ class ExprMutator
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
virtual Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expand Down Expand Up @@ -233,6 +234,94 @@ class ExprMutator
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief A wrapper around ExprFunctor which traverses the Graph Normal AST.
*
* PostOrderGraphVisitor treats Expr as dataflow graph.
* PostOrderGraphVisitor provides utitilies for doing a Depth-first
* pre or post order traversal of the graph without recursion.
* It accepts a std::function as part of it's construction which is used
* to do the actual operatiuons on the graph.
*/
class PostOrderGraphVisitor : public ::tvm::relay::ExprFunctor<bool(const Expr&)> {
public:
PostOrderGraphVisitor() : visitor_([](const Expr&) {}), visit_count_(1) {}
PostOrderGraphVisitor(const std::function<void(const Expr&)>& visitor, int visit_count = 1)
: visitor_(visitor) {
CHECK(visit_count > 0) << "GraphMutator visit count must be greater than 0";
CHECK(visit_count < 10) << "GraphMutator visit count must be less than 10";
visit_count_ = visit_count;
}
bool VisitExpr(const Expr& expr) override;
bool VisitExpr_(const VarNode* op) override;
bool VisitExpr_(const ConstantNode* op) override;
bool VisitExpr_(const GlobalVarNode* op) override;
bool VisitExpr_(const OpNode* op) override;
bool VisitExpr_(const TupleNode* op) override;
bool VisitExpr_(const FunctionNode* op) override;
bool VisitExpr_(const CallNode* call_node) override;
bool VisitExpr_(const LetNode* op) override;
bool VisitExpr_(const IfNode* op) override;
bool VisitExpr_(const TupleGetItemNode* op) override;
bool VisitExpr_(const RefCreateNode* op) override;
bool VisitExpr_(const RefReadNode* op) override;
bool VisitExpr_(const RefWriteNode* op) override;
bool VisitExpr_(const ConstructorNode* op) override;
bool VisitExpr_(const MatchNode* op) override;

protected:
/* \brief Function to push a note to the stack if it hasn't been visited
* Also returns false if the node hasn't been visited */
bool PushToStack(const Expr& expr);
/*! \brief std::function used to process nodes. */
std::function<void(const Expr&)> visitor_;
/*! \brief Number of times fo visit each node. */
size_t visit_count_;
// Internal visiting counter
std::unordered_map<const Object*, size_t> visit_counter_;
/*! \brief Internal Manually Managed Stack.
* The stack contains a pair of <Expr, bool> where the boolean indicates
* whether or not we've already checked the status of this nodes inputs */
std::stack<std::pair<Expr, bool>> stack_;
};

/*!
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprRewriter treats Expr as dataflow graph, and only Rewrites each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
* ExprRewriter does not iterate over the graph, it assumes another method,
* like PostOrderGraphVisitor, is passing expressions in an appropriate order
*/
class ExprRewriter : protected ::tvm::relay::ExprMutator {
public:
ExprRewriter(std::function<void(const Expr&)> f) : visitor(PostOrderGraphVisitor(f)) {}
ExprRewriter()
: visitor(PostOrderGraphVisitor([this](const Expr& expr) { ExprMutator::VisitExpr(expr); })) {
}
/*!
* \brief Rewrite is alias for VisitExpr
* Rewrite effective does the same thing as ExprMutator's Mutate,
* except it prepends it by a PostOrderGraphTraversal to ensure the inputs have been processed
* \return expr.
*/
virtual Expr Rewrite(const Expr& expr) {
if (memo_.count(expr)) {
return memo_[expr];
} else {
this->visitor.VisitExpr(expr);
Expr ret = this->VisitExpr(expr);
memo_[expr] = ret;
return ret;
}
}

protected:
Expr Mutate(const Expr& expr) override;
PostOrderGraphVisitor visitor;
};

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
Expand Down
127 changes: 125 additions & 2 deletions src/relay/ir/expr_functor.cc
Expand Up @@ -32,6 +32,129 @@
namespace tvm {
namespace relay {

Expr ExprRewriter::Mutate(const Expr& expr) {
CHECK(this->memo_.count(expr)) << "ExprRewriter::Mutate called on a Node with unprocessed inputs";
return this->memo_[expr];
}

bool PostOrderGraphVisitor::PushToStack(const Expr& expr) {
bool out = true;
if (visit_counter_[expr.get()] < visit_count_) {
stack_.push({expr, false});
out = false;
} else {
visit_counter_[expr.get()]++;
}
// return true if this node was already visited,
// or false if we had to push it onto the stack
return out;
}

bool PostOrderGraphVisitor::VisitExpr(const Expr& expr) {
PushToStack(expr);
while (stack_.size() > 0) {
std::pair<Expr, bool>& current = stack_.top();
if (visit_counter_[current.first.get()] < visit_count_) {
if (current.second // If we have already checked this node
|| current.first.as<TempExprNode>() // Or it's temporary
|| ExprFunctor::VisitExpr(current.first)) { // or we've already visited it's inputs
// Do post order visitation
visitor_(current.first);
visit_counter_[current.first.get()]++;
stack_.pop();
} else {
// Otherwise. the VisitExpr function just pushed univisted children onto the stack
// The next time we see this item on the stack, it's children must have been visted
current.second = true;
}
} else {
visit_counter_[current.first.get()]++;
stack_.pop();
}
}
// This is just to match the template API of ExprFunctor
return true;
}

bool PostOrderGraphVisitor::VisitExpr_(const VarNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const ConstantNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const GlobalVarNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const OpNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const TupleNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
children_processed &= PushToStack(*it);
}
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const FunctionNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
children_processed &= PushToStack(op->body);
for (auto it = op->params.rbegin(); it != op->params.rend(); ++it) {
children_processed &= PushToStack(*it);
}
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const CallNode* call_node) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
children_processed &= PushToStack(*it);
}
children_processed &= PushToStack(call_node->op);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const LetNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->body);
children_processed &= PushToStack(op->value);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const IfNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->false_branch);
children_processed &= PushToStack(op->true_branch);
children_processed &= PushToStack(op->cond);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const TupleGetItemNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->tuple);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefCreateNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->value);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefReadNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->ref);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefWriteNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->value);
children_processed &= PushToStack(op->ref);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const ConstructorNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const MatchNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->clauses.rbegin(); it != op->clauses.rend(); ++it) {
children_processed &= PushToStack((*it)->rhs);
}
children_processed &= PushToStack(op->data);
return children_processed;
}

Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
Expand Down Expand Up @@ -211,12 +334,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
return MatchNode::make(this->Mutate(m->data), clauses, m->complete);
}

Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
return ClauseNode::make(p, this->Mutate(c->rhs));
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Expand Down
19 changes: 10 additions & 9 deletions src/relay/pass/forward_rewrite.cc
Expand Up @@ -33,10 +33,11 @@ namespace relay {
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
class TempRealizer : private ExprMutator {
class TempRealizer : private ExprRewriter {
public:
TempRealizer() : ExprRewriter([this](const Expr& expr){this->VisitExpr(expr);}) {}
Expr Realize(Expr expr) {
return VisitExpr(expr);
return Rewrite(expr);
}

private:
Expand All @@ -50,15 +51,15 @@ class TempRealizer : private ExprMutator {
res = temp->Realize();

} else {
res = ExprFunctor::VisitExpr(expr);
res = ExprRewriter::VisitExpr(expr);
}
memo_[res] = res;
memo_[expr] = res;
return res;
}
}
};

class ForwardRewriter : private ExprMutator {
class ForwardRewriter : private ExprRewriter {
public:
ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<ObjectRef(const Call&)> fcontext,
Expand All @@ -76,7 +77,7 @@ class ForwardRewriter : private ExprMutator {


// Transform expression.
Expr Rewrite(Expr expr) {
Expr Rewrite(const Expr& expr) override {
if (fmulti_ref_trigger_ != nullptr) {
ref_counter_ = GetExprRefCount(expr);
}
Expand All @@ -98,21 +99,21 @@ class ForwardRewriter : private ExprMutator {

Expr VisitExpr(const Expr& expr) final {
// by default always realize.
return realizer_.Realize(ExprMutator::VisitExpr(expr));
return realizer_.Realize(ExprRewriter::Rewrite(expr));
}

// Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) {
if (fmulti_ref_trigger_ != nullptr) {
Expr ret = ExprMutator::VisitExpr(expr);
Expr ret = ExprRewriter::VisitExpr(expr);
auto it = ref_counter_.find(expr.get());
CHECK(it != ref_counter_.end());
if (it->second > 1) {
ret = fmulti_ref_trigger_(ret);
}
return ret;
} else {
return ExprMutator::VisitExpr(expr);
return ExprRewriter::VisitExpr(expr);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/util.cc
Expand Up @@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars")
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
class ExprRefCounter : private PostOrderGraphVisitor {
public:
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
Expand Down

0 comments on commit 9c952ee

Please sign in to comment.