Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternMatcher] Support matching tuples, call nodes, and functions with variable numbers of inputs #7754

Merged
merged 4 commits into from
Apr 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,22 @@ The final example is matching diamonds with a post-dominator relationship. We em
assert diamond.match(out)


Matching Fuzzy Patterns
=======================

The Dominator analysis above lets one match a subgraph of Relay AST that doesn't correspond to a set of patterns nodes exactly 1-to-1. There are a few other places where we support such "fuzzy" matching.

Tuples, Functions, and Call nodes with any number of inputs can be matched by passing `None` as the argument value, i.e.::

tuple_pattern = is_tuple(None)
func_pattern = FunctionPattern(None, wildcard() + wildcard())
call_pattern = func_pattern(None)

These patterns allow matching more generic classes patterns by constraining the use of the arguments rather than the number of arguments.

Additionally, we support matching Functions with fuzzy bodies, i.e., a function body that is under constrained by the pattern. The pattern `FunctionPattern([is_var(), is_var()], wildcard() + wildcard()])` will match `relay.Function([x, y], x + y)`, but it will also match `relay.Function([x, y], x * x + y)`. In the second case, the pattern doesn't perfectly constrain the body of the function, so the resulting match is fuzzy.


Pattern Language Design
=======================

Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class DFPattern(Node):
"""Base class of all Patterns."""

def __call__(self, *args):
return CallPattern(self, list(args))
args = list(args)
if len(args) == 1 and args[0] is None:
args = None
return CallPattern(self, args)

def __or__(self, other):
return AltPattern(self, other)
Expand Down
107 changes: 73 additions & 34 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
}
return false;
};

// logic
auto watermark = matched_nodes_.size();
if (const auto* call_node = expr.as<CallNode>()) {
Expand All @@ -253,13 +254,15 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
if (pattern_args.defined()) {
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
}
} else {
matches = false;
}
} else {
matches = false;
}
if (!matches) {
ClearMap(watermark2);
Expand Down Expand Up @@ -381,14 +384,16 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr
bool matches = false;
if (const auto* func = expr.as<FunctionNode>()) {
matches = true;
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
if (op->params.defined()) {
size_t i = 0;
if (op->params.size() == func->params.size()) {
while (matches && i < op->params.size()) {
matches &= VisitDFPattern(op->params[i], func->params[i]);
++i;
}
} else {
matches = false;
}
} else {
matches = false;
}
if (matches) {
matches &= VisitDFPattern(op->body, func->body);
Expand All @@ -409,12 +414,16 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const
bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_node = expr.as<TupleNode>()) {
if (op->fields.size() == tuple_node->fields.size()) {
matches = true;
size_t i = 0;
while (matches && i < op->fields.size()) {
matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
++i;
matches = true;
if (op->fields.defined()) {
if (op->fields.size() == tuple_node->fields.size()) {
size_t i = 0;
while (matches && i < op->fields.size()) {
matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
++i;
}
} else {
matches = false;
}
}
}
Expand Down Expand Up @@ -657,7 +666,6 @@ class PatternGrouper {
int var_number = 0;

auto node_map = matcher_->GetMemo();

// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
Expand All @@ -669,11 +677,13 @@ class PatternGrouper {
}
}
}
// Don't treat Function params as input variables for partition
if (auto op = node->ref_.as<FunctionPatternNode>()) {
for (auto fuzzy_op : op->params) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
// Don't treat Function params or body as input variables for partition
if (node->ref_.as<FunctionPatternNode>()) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (auto node : graph.topological_order_) {
fuzzy_matches.insert(node->ref_);
}
}
}
Expand All @@ -686,22 +696,46 @@ class PatternGrouper {

std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
Array<Var> params;

for (auto node : pattern_graph_.topological_order_) {
if (node->inputs_.size() == 0) {
auto make_input = [&](const Expr& input) {
if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref_)) {
inputs[input] =
Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
group.args.push_back(input);
params.push_back(inputs[input]);
var_number++;
}
};
auto tuple = node->ref_.as<TuplePatternNode>();
auto call = node->ref_.as<CallPatternNode>();
if (tuple && !tuple->fields.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
group.args.push_back(match);
params.push_back(inputs[match]);
var_number++;
for (auto input : match.as<TupleNode>()->fields) {
make_input(input);
}
}
}
} else if (call && !call->args.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
for (auto input : match.as<CallNode>()->args) {
make_input(input);
}
}
}
} else if (node->inputs_.size() == 0) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
make_input(match);
}
}
}
}

Expand Down Expand Up @@ -898,6 +932,11 @@ class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String, ObjectRef>& attrs,
PackedFunc check) {
if (pattern.as<FunctionPatternNode>()) {
LOG(WARNING) << "Partioning a Function that isn't called doesn't make sense, skipping"
<< pattern;
return pre;
}
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
Expand Down
18 changes: 12 additions & 6 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPatte

void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
VisitDFPattern(op->op);
for (auto arg : op->args) {
VisitDFPattern(arg);
if (op->args.defined()) {
for (auto arg : op->args) {
VisitDFPattern(arg);
}
}
}

Expand All @@ -63,8 +65,10 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
for (auto param : op->params) {
VisitDFPattern(param);
if (op->params.defined()) {
for (auto param : op->params) {
VisitDFPattern(param);
}
}
VisitDFPattern(op->body);
}
Expand All @@ -76,8 +80,10 @@ void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
}

void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
for (auto field : op->fields) {
VisitDFPattern(field);
if (op->fields.defined()) {
for (auto field : op->fields) {
VisitDFPattern(field);
}
}
}

Expand Down
18 changes: 12 additions & 6 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->op, graph_.node_map_[GetRef<DFPattern>(op)]);
for (auto arg : op->args) {
VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->args.defined()) {
for (auto arg : op->args) {
VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
}

Expand All @@ -262,8 +264,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->params.defined()) {
for (auto param : op->params) {
VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand All @@ -277,8 +281,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
}

void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override {
for (auto field : op->fields) {
VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
if (op->fields.defined()) {
for (auto field : op->fields) {
VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
}

Expand Down
Loading