Skip to content

Commit

Permalink
Partially working, not passing all tests yet
Browse files Browse the repository at this point in the history
passes tests when disabling GetExprRefCount, I think I have a bug in visit counting

fix GetExprRefCount

Fix a subtle bug with nested recursive/non-recursive scopes
  • Loading branch information
Matthew Brookhart committed Mar 18, 2020
1 parent c54feac commit f887049
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 183 deletions.
121 changes: 70 additions & 51 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,91 +234,110 @@ class ExprMutator
};

/*!
* \brief A wrapper around ExprFunctor which traverses the Graph Normal AST.
* \brief a helper class for expanding dataflow regions of the graph non-recursively
*
* PostOrderGraphVisitor treats Expr as dataflow graph.
* PostOrderGraphVisitor provides utitilies for doing a Depth-first
* pre or post order traversal of the graph without recursion.
* It accepts a std::function as part of it's construction which is used
* to do the actual operatiuons on the graph.
* DataflowExpander takes a visit function as an argument and provides a method
* called ExpandDataflow, which will non-recursively call the visit function
* on Dataflow subgraphs in Post-DFS order.
*
* This class is not meant to be used on it's own, users should instead
* use DataflowVisitor and DataflowMutator which wrap this class
*/
class DataflowExpander {
public:
DataflowExpander(std::function<void(const Expr&)> visit, size_t visit_count = 1)
: visit_(visit), visit_count_(visit_count) {}
void ExpandDataflow(const Expr& Expr);
std::unordered_map<const Object*, size_t> GetVisitCounter() { return visit_counter_; }

protected:
/* \brief Function to push a note to the stack if it hasn't been visited
* Also returns false if the node hasn't been visited */
bool PushToStack(const Expr& expr, std::stack<std::pair<Expr, bool>>& stack);
/*! \brief std::function used to process nodes. */
std::function<void(const Expr&)> visit_;
/*! \brief Number of times fo visit each node. */
size_t visit_count_;
// Internal visiting counter
std::unordered_map<const Object*, size_t> visit_counter_;
};

/*!
* \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
*
* DataflowVisitor treats Expr as dataflow graph, and visits in post-DFS order
* It provides an option to run an arbitrary std::function<void(const Expr&)> on
* every node it visits.
*
* DataflowVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it uses DataflowExpander
* to non-recursively visit nested dataflow regions of the graph to prevent stack overflows
*/
class PostOrderGraphVisitor : public ::tvm::relay::ExprFunctor<bool(const Expr&)> {
class DataflowVisitor : public ::tvm::relay::ExprVisitor {
public:
PostOrderGraphVisitor() : visitor_([](const Expr&) {}), visit_count_(1) {}
PostOrderGraphVisitor(const std::function<void(const Expr&)>& visitor, int visit_count = 1)
: visitor_(visitor) {
DataflowVisitor() : DataflowVisitor([](const Expr&) {}) {}
DataflowVisitor(const std::function<void(const Expr&)>& visitor, int visit_count = 1)
: visitor_(visitor),
expander_(
[this](const Expr& expr) {
ExprFunctor<void(const Expr&)>::VisitExpr(expr);
visitor_(expr);
this->visit_counter_[expr.get()]++;
},
visit_count) {
CHECK(visit_count > 0) << "GraphMutator visit count must be greater than 0";
CHECK(visit_count < 10) << "GraphMutator visit count must be less than 10";
visit_count_ = visit_count;
}
bool VisitExpr(const Expr& expr) override;
bool VisitExpr_(const VarNode* op) override;
bool VisitExpr_(const ConstantNode* op) override;
bool VisitExpr_(const GlobalVarNode* op) override;
bool VisitExpr_(const OpNode* op) override;
bool VisitExpr_(const TupleNode* op) override;
bool VisitExpr_(const FunctionNode* op) override;
bool VisitExpr_(const CallNode* call_node) override;
bool VisitExpr_(const LetNode* op) override;
bool VisitExpr_(const IfNode* op) override;
bool VisitExpr_(const TupleGetItemNode* op) override;
bool VisitExpr_(const RefCreateNode* op) override;
bool VisitExpr_(const RefReadNode* op) override;
bool VisitExpr_(const RefWriteNode* op) override;
bool VisitExpr_(const ConstructorNode* op) override;
bool VisitExpr_(const MatchNode* op) override;
void VisitExpr(const Expr& expr) override {
if (this->visit_counter_[expr.get()] < visit_count_) {
expander_.ExpandDataflow(expr);
}
}

protected:
/* \brief Function to push a note to the stack if it hasn't been visited
* Also returns false if the node hasn't been visited */
bool PushToStack(const Expr& expr);
/*! \brief std::function used to process nodes. */
std::function<void(const Expr&)> visitor_;
/*! DataflowExpander to non-recursively visit dataflow regions and prevent stack overflows */
DataflowExpander expander_;
/*! \brief Number of times fo visit each node. */
size_t visit_count_;
// Internal visiting counter
std::unordered_map<const Object*, size_t> visit_counter_;
/*! \brief Internal Manually Managed Stack.
* The stack contains a pair of <Expr, bool> where the boolean indicates
* whether or not we've already checked the status of this nodes inputs */
std::stack<std::pair<Expr, bool>> stack_;
};

/*!
* \brief A wrapper around ExprFunctor which functionally updates the AST.
* \brief A wrapper around ExprMutator which functionally updates the AST.
*
* ExprRewriter treats Expr as dataflow graph, and only Rewrites each Expr once.
* DataflowMutator treats Expr as dataflow graph, and only Rewrites each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
* ExprRewriter does not iterate over the graph, it assumes another method,
* like PostOrderGraphVisitor, is passing expressions in an appropriate order
*
* DataflowMutator provides the same recursive API as ExprMutator, and uses
* recursion to traverse most forms of the IR, but under the hood it uses DataflowExpander
* to non-recursively visit nested dataflow regions of the graph to prevent stack overflow
*/
class ExprRewriter : protected ::tvm::relay::ExprMutator {
class DataflowMutator : protected ::tvm::relay::ExprMutator {
public:
ExprRewriter(std::function<void(const Expr&)> f) : visitor(PostOrderGraphVisitor(f)) {}
ExprRewriter()
: visitor(PostOrderGraphVisitor([this](const Expr& expr) { ExprMutator::VisitExpr(expr); })) {
}
DataflowMutator()
: expander_(DataflowExpander([this](const Expr& expr) { this->VisitExpr(expr); })) {}
/*!
* \brief Rewrite is alias for VisitExpr
* Rewrite effective does the same thing as ExprMutator's Mutate,
* except it prepends it by a PostOrderGraphTraversal to ensure the inputs have been processed
* \brief Mutate the Expression
* This override of Mutate aeffective does the same thing as ExprMutator's Mutate,
* except it prepends it by a non-recursive Dataflow expansion to prevent stack overflow
* \return expr.
*/
virtual Expr Rewrite(const Expr& expr) {
Expr Mutate(const Expr& expr) final {
if (memo_.count(expr)) {
return memo_[expr];
} else {
this->visitor.VisitExpr(expr);
expander_.ExpandDataflow(expr);
Expr ret = this->VisitExpr(expr);
memo_[expr] = ret;
return ret;
}
}

protected:
Expr Mutate(const Expr& expr) override;
PostOrderGraphVisitor visitor;
DataflowExpander expander_;
};

/*!
Expand Down
4 changes: 2 additions & 2 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private PostOrderGraphVisitor {
class ExprRefCounter : private DataflowVisitor {
public:
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
return expander_.GetVisitCounter();
}
};
return ExprRefCounter().Get(body);
Expand Down
146 changes: 41 additions & 105 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,10 @@
namespace tvm {
namespace relay {

Expr ExprRewriter::Mutate(const Expr& expr) {
CHECK(this->memo_.count(expr)) << "ExprRewriter::Mutate called on a Node with unprocessed inputs";
return this->memo_[expr];
}

bool PostOrderGraphVisitor::PushToStack(const Expr& expr) {
bool DataflowExpander::PushToStack(const Expr& expr, std::stack<std::pair<Expr, bool>>& stack) {
bool out = true;
if (visit_counter_[expr.get()] < visit_count_) {
stack_.push({expr, false});
stack.push({expr, false});
out = false;
} else {
visit_counter_[expr.get()]++;
Expand All @@ -50,109 +45,50 @@ bool PostOrderGraphVisitor::PushToStack(const Expr& expr) {
return out;
}

bool PostOrderGraphVisitor::VisitExpr(const Expr& expr) {
PushToStack(expr);
while (stack_.size() > 0) {
std::pair<Expr, bool>& current = stack_.top();
if (visit_counter_[current.first.get()] < visit_count_) {
if (current.second // If we have already checked this node
|| current.first.as<TempExprNode>() // Or it's temporary
|| ExprFunctor::VisitExpr(current.first)) { // or we've already visited it's inputs
// Do post order visitation
visitor_(current.first);
visit_counter_[current.first.get()]++;
stack_.pop();
} else {
// Otherwise. the VisitExpr function just pushed univisted children onto the stack
// The next time we see this item on the stack, it's children must have been visted
void DataflowExpander::ExpandDataflow(const Expr& expr) {
/*! \brief Internal Manually Managed Stack.
* The stack contains a pair of <const ExprNoe*, bool> where the boolean indicates
* whether or not we've already checked the status of this nodes inputs
* The stack is initialized locally in the ExpandDataflow function to prevent non-recursive
* behavior and recursive behavior interacting.
* */
std::stack<std::pair<Expr, bool>> stack;
PushToStack(expr, stack);
while (stack.size() > 0) {
std::pair<Expr, bool>& current = stack.top();
Expr node = current.first;
if (visit_counter_[node.get()] < visit_count_) {
bool visit_now = true;
if (!current.second) {
if (const CallNode* op = node.as<CallNode>()) {
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
visit_now &= PushToStack(*it, stack);
}
visit_now &= PushToStack(op->op, stack);
} else if (const TupleNode* op = node.as<TupleNode>()) {
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
visit_now &= PushToStack(*it, stack);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
visit_now &= PushToStack(op->tuple, stack);
}
current.second = true;
}
if (visit_now) {
// Do post order visitation
visit_(node);
visit_counter_[node.get()]++;
stack.pop();
}
} else {
visit_counter_[current.first.get()]++;
stack_.pop();
visit_counter_[node.get()]++;
stack.pop();
}
}
// This is just to match the template API of ExprFunctor
return true;
}

bool PostOrderGraphVisitor::VisitExpr_(const VarNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const ConstantNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const GlobalVarNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const OpNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const TupleNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
children_processed &= PushToStack(*it);
}
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const FunctionNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
children_processed &= PushToStack(op->body);
for (auto it = op->params.rbegin(); it != op->params.rend(); ++it) {
children_processed &= PushToStack(*it);
}
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const CallNode* call_node) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
children_processed &= PushToStack(*it);
}
children_processed &= PushToStack(call_node->op);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const LetNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->body);
children_processed &= PushToStack(op->value);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const IfNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->false_branch);
children_processed &= PushToStack(op->true_branch);
children_processed &= PushToStack(op->cond);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const TupleGetItemNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->tuple);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefCreateNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->value);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefReadNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->ref);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const RefWriteNode* op) {
bool children_processed = true;
children_processed &= PushToStack(op->value);
children_processed &= PushToStack(op->ref);
return children_processed;
}
bool PostOrderGraphVisitor::VisitExpr_(const ConstructorNode* op) { return true; }
bool PostOrderGraphVisitor::VisitExpr_(const MatchNode* op) {
bool children_processed = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->clauses.rbegin(); it != op->clauses.rend(); ++it) {
children_processed &= PushToStack((*it)->rhs);
}
children_processed &= PushToStack(op->data);
return children_processed;
}

Expr ExprMutator::VisitExpr(const Expr& expr) {
Expand Down
16 changes: 6 additions & 10 deletions src/relay/transforms/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Eliminator : private ExprMutator {
};

// calculate the dependency graph from expression
class CalcDep : private PostOrderGraphVisitor {
class CalcDep : private DataflowVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
FindDef fd;
Expand All @@ -105,25 +105,21 @@ class CalcDep : private PostOrderGraphVisitor {

private:
explicit CalcDep(const VarMap<Expr>& expr_map)
: PostOrderGraphVisitor([](const Expr&) {}, 2), expr_map_(expr_map) {}
: DataflowVisitor([](const Expr&) {}, 2), expr_map_(expr_map) {}
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;


bool VisitExpr_(const LetNode* l) final {
bool children_processed = true;
children_processed &= PushToStack(l->body);
return children_processed;
void VisitExpr_(const LetNode* l) final {
VisitExpr(l->body);
}

bool VisitExpr_(const VarNode* v) final {
bool children_processed = true;
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
++use_map_[var];
if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
children_processed &= PushToStack(expr_map_[var]);
VisitExpr(expr_map_[var]);
}
return children_processed;
}
};

Expand Down
Loading

0 comments on commit f887049

Please sign in to comment.