Skip to content

Commit

Permalink
add testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuibin0 committed May 13, 2024
1 parent 08f91b9 commit 26856ad
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# convention.
K_ELEMWISE = 0
K_BROADCAST = 1

K_INJECTIVE = 2

## NODE TESTS
def test_expr_pattern():
Expand Down Expand Up @@ -696,6 +696,28 @@ def test_match_dominator():
assert diamond.match(out)


def test_match_dominator2():
# Pattern
conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None)
broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None)
path_pat = eltwise_pat | broadcast_pat
injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard())
pattern = injective_pat.dominates(conv2d_pat, path_pat)

# Graph
inp = relay.var("input")
weight = relay.var("weight")
bias = relay.var("bias")
conv2d = relay.op.nn.conv2d(inp, weight)
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)
reshape = relay.op.reshape(relu, newshape=[-1, 2, 8])

# Check
assert pattern.match(reshape)


def test_not_match_dominator():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
Expand Down

0 comments on commit 26856ad

Please sign in to comment.