Skip to content

Commit

Permalink
Fix a problem with default recursion for dataflow nodes
Browse files Browse the repository at this point in the history
mark DataflowVisitor methods as override
  • Loading branch information
Matthew Brookhart committed Mar 23, 2020
1 parent fde6f2b commit 62cddf2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Expand Up @@ -249,6 +249,10 @@ class DataflowVisitor : public ::tvm::relay::ExprVisitor {
* \brief VisitExpr is finalized to preserve call expansion of dataflow regions
*/
void VisitExpr(const Expr& expr) final;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;


protected:
/*!
Expand Down
13 changes: 11 additions & 2 deletions src/relay/ir/expr_functor.cc
Expand Up @@ -108,7 +108,7 @@ DataflowVisitor::DataflowVisitor(int visit_limit) {
}

void DataflowVisitor::VisitLeaf(const Expr& expr) {
if (visit_counter_[expr.get()] == 0) {
if (visit_counter_[expr.get()] < visit_limit_) {
ExprFunctor::VisitExpr(expr);
}
visit_counter_[expr.get()]++;
Expand All @@ -126,11 +126,20 @@ bool DataflowVisitor::CheckVisited(const Expr& expr) {
void DataflowVisitor::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (visit_counter_[expr.get()] < 1) {
if (visit_counter_[expr.get()] < visit_limit_) {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
}
}

// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void DataflowVisitor::VisitExpr_(const CallNode* op) {}

// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void DataflowVisitor::VisitExpr_(const TupleNode* op) {}

// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void DataflowVisitor::VisitExpr_(const TupleGetItemNode* op) {}

void DataflowMutator::VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
this->VisitExpr(expr);
Expand Down
3 changes: 2 additions & 1 deletion src/relay/transforms/dead_code.cc
Expand Up @@ -116,7 +116,8 @@ class CalcDep : private DataflowVisitor {
// used 1 times (inline)
// used 2 times (dont do anything).
if (visit_counter_[e.get()] <= 2) {
ExprFunctor::VisitExpr(e);
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(e);
}
}

Expand Down

0 comments on commit 62cddf2

Please sign in to comment.