Skip to content

Commit

Permalink
Fix super().visit_var_binding_ in PyExprVisitor and PyExprMutator (#1…
Browse files Browse the repository at this point in the history
…4754)

It will run into an infinite recursion if we call super().visit_var_binding_ in PyExprVisitor and PyExprMutator. This PR fixes the behavior.
  • Loading branch information
jinhongyii committed May 3, 2023
1 parent 107ef2b commit 7a13d1e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/relax/ir/py_expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,13 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr")

TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding")
.set_body_typed([](PyExprVisitor visitor, const Binding& binding) {
visitor->ExprVisitor::VisitBinding(binding);
if (const auto* ptr = binding.as<VarBindingNode>()) {
visitor->ExprVisitor::VisitBinding_(ptr);
} else if (const auto* ptr = binding.as<MatchCastNode>()) {
visitor->ExprVisitor::VisitBinding_(ptr);
} else {
LOG(FATAL) << "unreachable";
}
});

TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock")
Expand Down Expand Up @@ -604,7 +610,13 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr")

TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding")
.set_body_typed([](PyExprMutator mutator, const Binding& binding) {
return mutator->ExprMutator::VisitBinding(binding);
if (const auto* ptr = binding.as<VarBindingNode>()) {
return mutator->ExprMutator::VisitBinding_(ptr);
} else if (const auto* ptr = binding.as<MatchCastNode>()) {
return mutator->ExprMutator::VisitBinding_(ptr);
} else {
LOG(FATAL) << "unreachable";
}
});

TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock")
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relax/test_expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from tvm.script import relax as R
import pytest
import tvm.testing

m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
x = relax.Var("x", R.Tensor([n], "float32"))
Expand Down Expand Up @@ -674,13 +675,23 @@ def visit_call_(self, op: Call) -> None:
pass


@R.function
def dummy(x: R.Tensor((10, 10))):
lv = R.add(x, R.const(1))
return lv


def test_call_visitor_super():
@relax.expr_functor.visitor
class InternalVisitor(PyExprVisitor):
def __init__(self) -> None:
super().__init__()
self.log = ASTLog()

def visit_var_binding_(self, binding: relax.VarBinding) -> None:
self.log.add("VarBinding")
super().visit_var_binding_(binding)

def visit_call_(self, op: Call) -> None:
self.log.add("InternalCall")
super().visit_call_(op) # call PyExprVisitor.visit_call_
Expand All @@ -706,6 +717,10 @@ def visit_call_(self, op: Call) -> None:
lv.visit_expr(call_node)
assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"])

lv = LeafVisitor()
lv.visit_expr(dummy)
assert str(lv.log) == "\n".join(["VarBinding", "LeafCall", "InternalCall", "Op", "Var", "Var"])


def test_call_mutator_super():
@relax.expr_functor.mutator
Expand All @@ -714,6 +729,10 @@ def __init__(self) -> None:
super().__init__()
self.log = ASTLog()

def visit_var_binding_(self, binding: relax.VarBinding) -> None:
self.log.add("VarBinding")
super().visit_var_binding_(binding)

def visit_call_(self, op: Call) -> None:
self.log.add("InternalCall")
return super().visit_call_(op) # call PyExprMutator.visit_call_
Expand Down Expand Up @@ -741,6 +760,10 @@ def visit_call_(self, op: Call) -> None:
lm.visit_expr(call_node)
assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"])

lm = LeafMutator()
lm.visit_expr(dummy)
assert str(lm.log) == "\n".join(["VarBinding", "LeafCall", "InternalCall", "Op", "Var", "Var"])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 7a13d1e

Please sign in to comment.