Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Mar 18, 2020
1 parent f887049 commit 8871e0b
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 128 deletions.
85 changes: 8 additions & 77 deletions include/tvm/relay/expr_functor.h
Expand Up @@ -32,7 +32,6 @@
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>

#include <stack>
#include <string>
#include <utility>
#include <unordered_map>
Expand Down Expand Up @@ -233,35 +232,6 @@ class ExprMutator
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief a helper class for expanding dataflow regions of the graph non-recursively
*
* DataflowExpander takes a visit function as an argument and provides a method
* called ExpandDataflow, which will non-recursively call the visit function
* on Dataflow subgraphs in Post-DFS order.
*
* This class is not meant to be used on it's own, users should instead
* use DataflowVisitor and DataflowMutator which wrap this class
*/
class DataflowExpander {
public:
DataflowExpander(std::function<void(const Expr&)> visit, size_t visit_count = 1)
: visit_(visit), visit_count_(visit_count) {}
void ExpandDataflow(const Expr& Expr);
std::unordered_map<const Object*, size_t> GetVisitCounter() { return visit_counter_; }

protected:
/* \brief Function to push a note to the stack if it hasn't been visited
* Also returns false if the node hasn't been visited */
bool PushToStack(const Expr& expr, std::stack<std::pair<Expr, bool>>& stack);
/*! \brief std::function used to process nodes. */
std::function<void(const Expr&)> visit_;
/*! \brief Number of times fo visit each node. */
size_t visit_count_;
// Internal visiting counter
std::unordered_map<const Object*, size_t> visit_counter_;
};

/*!
* \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
*
Expand All @@ -275,33 +245,13 @@ class DataflowExpander {
*/
class DataflowVisitor : public ::tvm::relay::ExprVisitor {
public:
DataflowVisitor() : DataflowVisitor([](const Expr&) {}) {}
DataflowVisitor(const std::function<void(const Expr&)>& visitor, int visit_count = 1)
: visitor_(visitor),
expander_(
[this](const Expr& expr) {
ExprFunctor<void(const Expr&)>::VisitExpr(expr);
visitor_(expr);
this->visit_counter_[expr.get()]++;
},
visit_count) {
CHECK(visit_count > 0) << "GraphMutator visit count must be greater than 0";
CHECK(visit_count < 10) << "GraphMutator visit count must be less than 10";
visit_count_ = visit_count;
}
void VisitExpr(const Expr& expr) override {
if (this->visit_counter_[expr.get()] < visit_count_) {
expander_.ExpandDataflow(expr);
}
}
DataflowVisitor(int visit_limit = 1);
void VisitExpr(const Expr& expr) final;

protected:
/*! \brief std::function used to process nodes. */
std::function<void(const Expr&)> visitor_;
/*! DataflowExpander to non-recursively visit dataflow regions and prevent stack overflows */
DataflowExpander expander_;
/*! \brief Number of times fo visit each node. */
size_t visit_count_;
virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
size_t visit_limit_;
};

/*!
Expand All @@ -316,30 +266,11 @@ class DataflowVisitor : public ::tvm::relay::ExprVisitor {
* to non-recursively visit nested dataflow regions of the graph to prevent stack overflow
*/
class DataflowMutator : protected ::tvm::relay::ExprMutator {
public:
DataflowMutator()
: expander_(DataflowExpander([this](const Expr& expr) { this->VisitExpr(expr); })) {}
/*!
* \brief Mutate the Expression
* This override of Mutate aeffective does the same thing as ExprMutator's Mutate,
* except it prepends it by a non-recursive Dataflow expansion to prevent stack overflow
* \return expr.
*/
Expr Mutate(const Expr& expr) final {
if (memo_.count(expr)) {
return memo_[expr];
} else {
expander_.ExpandDataflow(expr);
Expr ret = this->VisitExpr(expr);
memo_[expr] = ret;
return ret;
}
}

protected:
DataflowExpander expander_;
virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
Expr Mutate(const Expr& expr) final;
};

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
Expand Down
5 changes: 4 additions & 1 deletion src/relay/analysis/util.cc
Expand Up @@ -335,7 +335,10 @@ GetExprRefCount(const Expr& body) {
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return expander_.GetVisitCounter();
for (auto kv : visit_counter_) {
std::cout << kv.first << '\t' << kv.second << std::endl;
}
return std::move(visit_counter_);
}
};
return ExprRefCounter().Get(body);
Expand Down
149 changes: 100 additions & 49 deletions src/relay/ir/expr_functor.cc
Expand Up @@ -29,68 +29,119 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>

#include <stack>

namespace tvm {
namespace relay {

bool DataflowExpander::PushToStack(const Expr& expr, std::stack<std::pair<Expr, bool>>& stack) {
bool out = true;
if (visit_counter_[expr.get()] < visit_count_) {
stack.push({expr, false});
out = false;
} else {
visit_counter_[expr.get()]++;
}
// return true if this node was already visited,
// or false if we had to push it onto the stack
return out;
}

void DataflowExpander::ExpandDataflow(const Expr& expr) {
/*! \brief Internal Manually Managed Stack.
* The stack contains a pair of <const ExprNoe*, bool> where the boolean indicates
* whether or not we've already checked the status of this nodes inputs
* The stack is initialized locally in the ExpandDataflow function to prevent non-recursive
* behavior and recursive behavior interacting.
* */
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
PushToStack(expr, stack);
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
if (!fcheck_visited(expr)) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
std::pair<Expr, bool>& current = stack.top();
Expr node = current.first;
if (visit_counter_[node.get()] < visit_count_) {
bool visit_now = true;
if (!current.second) {
if (const CallNode* op = node.as<CallNode>()) {
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
visit_now &= PushToStack(*it, stack);
}
visit_now &= PushToStack(op->op, stack);
} else if (const TupleNode* op = node.as<TupleNode>()) {
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
visit_now &= PushToStack(*it, stack);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
visit_now &= PushToStack(op->tuple, stack);
}
current.second = true;
auto node = stack.top().first;
// if this node was visited through another path
// after being added to the stack ignore it.
if (fcheck_visited(expr)) {
stack.pop();
} else if (stack.top().second) {
// all the children has already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
if (visit_now) {
// Do post order visitation
visit_(node);
visit_counter_[node.get()]++;
stack.pop();
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
visit_counter_[node.get()]++;
// No need to expand the children directly run visit.
// terminal leaf, directly use visited.
fvisit_leaf(node);
stack.pop();
}
}
}

DataflowVisitor::DataflowVisitor(int visit_limit) {
CHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
CHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
visit_limit_ = visit_limit;
}

void DataflowVisitor::VisitLeaf(const Expr& expr) {
if (visit_counter_[expr.get()] == 0) {
ExprFunctor::VisitExpr(expr);
}
visit_counter_[expr.get()]++;
}

bool DataflowVisitor::CheckVisited(const Expr& expr) {
if (visit_counter_[expr.get()] < visit_limit_) {
return false;
} else {
visit_counter_[expr.get()]++;
return true;
}
}

void DataflowVisitor::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (visit_counter_[expr.get()] < 1) {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
}
}

void DataflowMutator::VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
this->VisitExpr(expr);
}
}

bool DataflowMutator::CheckVisited(const Expr& expr) {
if (memo_.count(expr)) {
return true;
} else {
return false;
}
}

Expr DataflowMutator::Mutate(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (memo_.count(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
Expr ret = this->VisitExpr(expr);
memo_[expr] = ret;
return ret;
}
}

Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
Expand Down
12 changes: 11 additions & 1 deletion src/relay/transforms/dead_code.cc
Expand Up @@ -105,10 +105,20 @@ class CalcDep : private DataflowVisitor {

private:
explicit CalcDep(const VarMap<Expr>& expr_map)
: DataflowVisitor([](const Expr&) {}, 2), expr_map_(expr_map) {}
: DataflowVisitor(2), expr_map_(expr_map) {}
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;

void VisitLeaf(const Expr& e) final {
visit_counter_[e.get()]++;
// The dce code seprate variable into three parts:
// used 0 times (remove)
// used 1 times (inline)
// used 2 times (dont do anything).
if (visit_counter_[e.get()] <= 2) {
ExprFunctor::VisitExpr(e);
}
}

void VisitExpr_(const LetNode* l) final {
VisitExpr(l->body);
Expand Down

0 comments on commit 8871e0b

Please sign in to comment.