Skip to content

Commit

Permalink
[PatternLang] Simplify Pattern API Implementations (#5703)
Browse files Browse the repository at this point in the history
* Add syntatic sugar; include pattern to API docs

* fix doc warnings
  • Loading branch information
comaniac committed Jun 2, 2020
1 parent afc239a commit 43162d6
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 113 deletions.
1 change: 1 addition & 0 deletions docs/api/python/index.rst
Expand Up @@ -37,6 +37,7 @@ Python API
relay/transform
relay/analysis
relay/backend
relay/dataflow_pattern
relay/testing
autotvm
rpc
Expand Down
25 changes: 25 additions & 0 deletions docs/api/python/relay/dataflow_pattern.rst
@@ -0,0 +1,25 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.dataflow_pattern
--------------------------

.. automodule:: tvm.relay.dataflow_pattern
:members:
:imported-members:
:exclude-members: Object, Node
:autosummary:
21 changes: 13 additions & 8 deletions docs/langref/relay_pattern.rst
Expand Up @@ -114,7 +114,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t
x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
tuple_pattern = TuplePattern((wildcard(), wildcard(), wildcard()))
tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
The next example is matching a pattern of batch_norm -> get(0) -> relu:
Expand All @@ -123,7 +123,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
def test_match_tuple_get_item():
bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
tuple_get_item_node = is_tuple_get_item(bn_node, 0)
pat = is_op('nn.relu')(tuple_get_item_node)
x = relay.var('x', shape=(1, 8))
Expand All @@ -142,7 +142,7 @@ if a specific parameter in a subgraph has been bound or not.
.. code-block:: python
def test_match_constant():
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
Expand All @@ -162,12 +162,12 @@ if a specific parameter in a subgraph has been bound or not.
assert pattern.match(mod['main'].body)
On the other hand, if you need to match the constant with a specific value, you can directly
use ``ExprPattern``. This could be useful for algebraic simplify.
use ``is_expr``. This could be useful for algebraic simplify.

.. code-block:: python
def test_match_plus_zero():
zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
pattern = wildcard() + zero
x = relay.Var('x')
Expand All @@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond

def test_match_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand All @@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em

def test_match_dom_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_elemwise, reduction)

Expand All @@ -240,7 +240,12 @@ The high level design is to introduce a language of patterns for now we propose
| pattern(pattern1, ... patternN)
| has_type(pattern, type)
| has_attr(pattern, attrs)
| is_input(name)
| is_var(name)
| is_constant()
| is_expr(expr)
| is_op(op_name)
| is_tuple()
| is_tuple_get_item()
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)

Expand Down

0 comments on commit 43162d6

Please sign in to comment.