Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] add redirecting operation to dataflow pattern graph #15392

Merged
merged 2 commits into from Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Expand Up @@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

/*! \brief If the wildcard is redirected, then pattern is not nullptr, and the wildcard
* redirects to the pattern. */
Optional<DFPattern> pattern{nullptr};

static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
};
Expand All @@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode {
class WildcardPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);

void redirect_to(DFPattern pat) const;
};

class TypePattern;
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Expand Up @@ -722,6 +722,19 @@ class WildcardPattern(DFPattern):
def __init__(self):
self.__init_handle_by_constructor__(ffi.WildcardPattern)

def redirect_to(
self,
pat: "DFPattern",
):
"""Redirect the WildcardPattern to another pattern

Parameters
----------
pat: relay.dataflow_pattern.DFPattern
The pattern that wildcard is redirected to.
"""
ffi.WildcardPattern_redirect_to(self, pat)


@register_df_node
class TypePattern(DFPattern):
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Expand Up @@ -485,7 +485,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr
}

bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
return true;
if (op->pattern) {
return VisitDFPattern(op->pattern.value(), expr);
} else {
return true;
}
}

bool MatchPattern(DFPattern pattern, Expr expr) {
Expand Down
10 changes: 10 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Expand Up @@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
<< ")";
});

void WildcardPattern::redirect_to(DFPattern pat) const {
WildcardPatternNode* ptr = static_cast<WildcardPatternNode*>(get_mutable());
ptr->pattern = pat;
}

TVM_REGISTER_NODE_TYPE(WildcardPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to")
.set_body_typed([](WildcardPattern wildcard, DFPattern pat) {
return wildcard.redirect_to(pat);
});

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
auto w = WildcardPattern(make_object<WildcardPatternNode>());
return w;
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/dataflow_pattern_functor.cc
Expand Up @@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {
if (op->pattern) {
VisitDFPattern(op->pattern.value());
}
}

} // namespace relay
} // namespace tvm
7 changes: 6 additions & 1 deletion src/relay/ir/indexed_graph.cc
Expand Up @@ -537,7 +537,12 @@ std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pat

void VisitDFPattern_(const VarPatternNode* op) override {}

void VisitDFPattern_(const WildcardPatternNode* op) override {}
void VisitDFPattern_(const WildcardPatternNode* op) override {
if (op->pattern) {
auto node = graph_->item_to_node(GetRef<WildcardPattern>(op));
AddOutput(op->pattern.value(), node);
}
}

std::unique_ptr<IndexedGraph<DFPattern>> graph_;
};
Expand Down
56 changes: 56 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Expand Up @@ -1964,5 +1964,61 @@ def test_partition_parallel_branch_with_same_input():
assert tvm.ir.structural_equal(partitioned, reference)


def test_rewrite_with_pattern_recursion():
data = relay.var("data", relay.TensorType((2, 8), "float32"))
dense_weight = relay.const(np.zeros((4, 8)))
feat = relay.nn.dense(data, dense_weight)
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
oup = relay.cast(feat, "float32")

expected = relay.nn.relu(oup)

class TheRewrite(DFPatternCallback):
def __init__(self, pattern):
super(TheRewrite, self).__init__(rewrite_once=True)
self.pattern = pattern

def callback(self, pre, post, node_map):
return relay.nn.relu(post)

def test_reset_call_args():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect)
the_pattern2 = the_pattern | dense_pattern
wildcard_redirect.redirect_to(the_pattern2)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

def test_reset_alt_left():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
or_pattern = wildcard_redirect | dense_pattern
the_pattern = is_op("cast")(or_pattern)
wildcard_redirect.redirect_to(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

def test_reset_alt_right():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
or_pattern = dense_pattern | wildcard_redirect
the_pattern = is_op("cast")(or_pattern)
wildcard_redirect.redirect_to(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

test_reset_call_args()
test_reset_alt_left()
test_reset_alt_right()


if __name__ == "__main__":
tvm.testing.main()