# 数据流图模式

In [1]:
import numpy as np

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass

```{tip}
`K_ELEMWISE = 0`、`K_BROADCAST = 1` 这两个整数对应于 C++ 中为算子属性 `TOpPattern` 指定取值的 enum；由于 Python/C++ 调用约定的限制，这里只能直接使用整数，因此失去了类型安全。
```

In [2]:
# Integer stand-ins for the C++ operator-pattern enum that the "TOpPattern"
# attribute is compared against. Plain ints are used because the Python/C++
# calling convention does not carry the enum type, so there is no type safety.
K_ELEMWISE, K_BROADCAST = 0, 1

## 节点

In [5]:
def test_expr_pattern():
    """is_expr wraps a concrete Relay expression in an ExprPattern."""
    wrapped = is_expr(relay.var("x", shape=(4, 1)))
    assert isinstance(wrapped, ExprPattern)
    assert isinstance(wrapped.expr, relay.Var)


def test_var_pattern():
    """is_var builds a VarPattern that remembers the requested name."""
    var_pat = is_var("x")
    assert isinstance(var_pat, VarPattern)
    assert var_pat.name == "x"


def test_constant_pattern():
    """is_constant builds a ConstantPattern."""
    assert isinstance(is_constant(), ConstantPattern)


def test_wildcard_pattern():
    """wildcard() builds a WildcardPattern."""
    assert isinstance(wildcard(), WildcardPattern)


def test_CallPattern():
    """Applying an op pattern to argument patterns yields a CallPattern."""
    lhs, rhs = wildcard(), wildcard()
    call_pat = is_op("add")(lhs, rhs)
    assert isinstance(call_pat, CallPattern)
    for idx in (0, 1):
        assert isinstance(call_pat.args[idx], WildcardPattern)


def test_FunctionPattern():
    """FunctionPattern keeps its parameter patterns and its body pattern."""
    params = [wildcard(), wildcard()]
    body = is_op("add")(*params)
    func_pat = FunctionPattern(params, body)

    assert isinstance(func_pat, FunctionPattern)
    for idx in (0, 1):
        assert isinstance(func_pat.params[idx], WildcardPattern)
    assert isinstance(func_pat.body, CallPattern)
    for idx in (0, 1):
        assert isinstance(func_pat.body.args[idx], WildcardPattern)


def test_TuplePattern():
    """is_tuple builds a TuplePattern over the given field patterns."""
    tup_pat = is_tuple([wildcard(), wildcard()])
    assert isinstance(tup_pat, TuplePattern)
    for idx in range(2):
        assert isinstance(tup_pat.fields[idx], WildcardPattern)


def test_TupleGetItemPattern():
    """is_tuple_get_item wraps a tuple pattern and keeps it reachable."""
    inner = is_tuple([wildcard(), wildcard()])
    item_pat = is_tuple_get_item(inner, 1)

    assert isinstance(item_pat, TupleGetItemPattern)
    assert isinstance(item_pat.tuple, TuplePattern)
    for idx in range(2):
        assert isinstance(item_pat.tuple.fields[idx], WildcardPattern)


def test_AltPattern():
    """The `|` operator combines two patterns into an AltPattern."""
    add_or_sub = is_op("add") | is_op("subtract")
    assert isinstance(add_or_sub, AltPattern)


def test_TypePattern():
    """has_type builds a TypePattern holding the requested tensor type."""
    tensor_ty = relay.TensorType((10, 10), "float32")
    type_pat = has_type(tensor_ty)
    assert isinstance(type_pat, TypePattern)
    assert type_pat.type == tensor_ty


def test_DataTypePattern():
    """has_dtype builds a DataTypePattern holding the requested dtype."""
    wanted = "float16"
    dtype_pat = has_dtype(wanted)
    assert isinstance(dtype_pat, DataTypePattern)
    assert dtype_pat.dtype == wanted


def test_ShapePattern():
    """has_shape builds a ShapePattern holding the requested shape."""
    dims = [10, 10]
    shape_pat = has_shape(dims)
    assert isinstance(shape_pat, ShapePattern)
    assert tvm.ir.structural_equal(shape_pat.shape, dims)


def test_AttrPattern():
    """has_attr wraps an op pattern in an AttrPattern carrying the attrs."""
    attr_pat = is_op("add").has_attr({"TOpPattern": K_ELEMWISE})
    assert isinstance(attr_pat, AttrPattern)
    assert attr_pat.attrs["TOpPattern"] == K_ELEMWISE


def test_IfPattern():
    """is_if builds an IfPattern from condition, true- and false-branch patterns."""
    x_pat, y_pat = is_var("x"), is_var("y")
    if_pat = is_if(is_op("less")(x_pat, y_pat), x_pat, y_pat)

    assert isinstance(if_pat, IfPattern)
    assert isinstance(if_pat.cond, CallPattern)
    assert isinstance(if_pat.true_branch, VarPattern)
    assert isinstance(if_pat.false_branch, VarPattern)


def test_LetPattern():
    """is_let builds a LetPattern from a bound var, its value and a body."""
    x_pat, y_pat = is_var("x"), is_var("y")
    bound = is_var("let")
    let_pat = is_let(bound, is_op("less")(x_pat, y_pat), bound)

    assert isinstance(let_pat, LetPattern)
    assert isinstance(let_pat.var, VarPattern)
    assert isinstance(let_pat.value, CallPattern)
    assert isinstance(let_pat.body, VarPattern)


## matcher

In [8]:
# A bare op pattern matches only the operator registered under the same name.
add_op = relay.op.op.get("add")
sub_op = relay.op.op.get("subtract")
assert is_op("add").match(add_op)
assert not is_op("add").match(sub_op)

In [9]:
# `|` builds an alternation: either operator name is accepted.
is_add_or_sub = is_op("add") | is_op("subtract")
for op_name in ("add", "subtract"):
    assert is_add_or_sub.match(relay.op.op.get(op_name))

`call_commutative`（满足交换律的算子调用）:

In [10]:
x = relay.var("x")
y = relay.var("y")

# `add` and `multiply` are commutative: the patterns also accept the
# operands in swapped order.
add_pattern = is_op("add")(is_var("x"), is_var("y"))
mul_pattern = is_op("multiply")(is_var("x"), is_var("y"))
for candidate in (x + y, y + x):
    assert add_pattern.match(candidate)
for candidate in (x * y, y * x):
    assert mul_pattern.match(candidate)

In [11]:
x = relay.var("x")
y = relay.var("y")

# `subtract` and `divide` are not commutative: each pattern matches only
# when the operands appear in the same order as in the pattern.
sub_pattern = is_op("subtract")(is_var("x"), is_var("y"))
assert sub_pattern.match(x - y)
assert not sub_pattern.match(y - x)

div_pattern = is_op("divide")(is_var("x"), is_var("y"))
assert div_pattern.match(x / y)
assert not div_pattern.match(y / x)

`call`:

In [12]:
x = relay.var("x")
y = relay.var("y")

# `add` applied to two wildcards matches any two-operand add.
add_pattern = is_op("add")(wildcard(), wildcard())
assert add_pattern.match(x + y)

# A wildcard callee applied to `None` matches a call with any number of
# inputs: here a unary relu and a binary add.
call_pattern = wildcard()(None)
for candidate in (relay.op.nn.relu(x), relay.op.add(x, y)):
    assert call_pattern.match(candidate)

In [13]:
x, y = relay.var("x"), relay.var("y")
# An `add` pattern must reject a `subtract` call.
add_pattern = is_op("add")(wildcard(), wildcard())
difference = x - y
assert not add_pattern.match(difference)

`func`:

In [14]:
x, y = relay.var("x"), relay.var("y")
wc1, wc2 = wildcard(), wildcard()

# Parameters and body are matched structurally: fn(a, b) { a + b }.
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
assert func_pattern.match(relay.Function([x, y], x + y))

# With `None` as the parameter list, a function of any arity matches.
func_pattern = FunctionPattern(None, wildcard())
for candidate in (relay.Function([x], x), relay.Function([x, y], x + y)):
    assert func_pattern.match(candidate)

In [15]:
x, y = relay.var("x"), relay.var("y")
wc1, wc2 = wildcard(), wildcard()
# The body pattern requires `add`, so a function whose body subtracts fails.
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
mismatch = relay.Function([x, y], x - y)
assert not func_pattern.match(mismatch)

`if`:

In [16]:
# Pattern: if (x < y) then x else y, over vars named "x" and "y".
x_pat, y_pat = is_var("x"), is_var("y")
pat = is_if(is_op("less")(x_pat, y_pat), x_pat, y_pat)

# A concrete Relay `If` with exactly that structure.
x, y = relay.var("x"), relay.var("y")
cond = x < y
assert pat.match(relay.expr.If(cond, x, y))

In [17]:
x_pat, y_pat = is_var("x"), is_var("y")
pat = is_if(is_op("less")(x_pat, y_pat), x_pat, y_pat)

x, y = relay.var("x"), relay.var("y")
# Rejected: `>` instead of `<` in the condition ...
assert not pat.match(relay.expr.If(x > y, x, y))
# ... and the true/false branches swapped.
assert not pat.match(relay.expr.If(x < y, y, x))

`let`:

In [18]:
# Pattern: let %let = (x < y); %let
x_pat, y_pat = is_var("x"), is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x_pat, y_pat), let_var)

x, y = relay.var("x"), relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))

In [19]:
x_pat, y_pat = is_var("x"), is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x_pat, y_pat), let_var)

x, y = relay.var("x"), relay.var("y")
lv = relay.var("let")

# Rejected: the bound value uses `>` instead of `<` ...
assert not pat.match(relay.expr.Let(lv, x > y, lv))
# ... and the body is `lv * x` rather than the bare bound variable.
assert not pat.match(relay.expr.Let(lv, x < y, lv * x))

`option`:

In [20]:
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")

# relu(conv2d(_, _)), with an optional bias_add between conv2d and relu.
conv_pat = is_op("nn.conv2d")(wildcard(), wildcard())
pattern = is_op("nn.relu")(
    conv_pat.optional(lambda inner: is_op("nn.bias_add")(inner, wildcard()))
)

# Without the optional bias_add.
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
assert pattern.match(relu)

# With the optional bias_add.
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
assert pattern.match(relu)

# conv2d, then an optional relu, then an optional tanh (in that order).
pattern = (
    is_op("nn.conv2d")(wildcard(), wildcard())
    .optional(is_op("nn.relu"))
    .optional(is_op("tanh"))
)

conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
tanh = relay.op.tanh(conv2d)
tanh2 = relay.op.tanh(relu)
relu2 = relay.op.nn.relu(tanh)
for accepted in (conv2d, relu, tanh, tanh2):
    assert pattern.match(accepted)
# relu applied after tanh is the wrong order.
assert not pattern.match(relu2)

In [21]:
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")
# relu(conv2d(_, _)), with an optional bias_add between conv2d and relu.
pattern = is_op("nn.relu")(
    is_op("nn.conv2d")(wildcard(), wildcard()).optional(
        lambda conv: is_op("nn.bias_add")(conv, wildcard())
    )
)

# Outer op is tanh, not relu.
conv2d = relay.op.nn.conv2d(x, w)
tanh = relay.op.tanh(conv2d)
assert not pattern.match(tanh)

# Producer is dense (not conv2d) and outer op is tanh.
dense = relay.op.nn.dense(x, w)
tanh = relay.op.tanh(dense)
assert not pattern.match(tanh)

# Producer is dense even though bias_add and relu are present.
dense = relay.op.nn.dense(x, w)
bias_add = relay.op.nn.bias_add(dense, b)
relu = relay.op.nn.relu(bias_add)
assert not pattern.match(relu)

# A generic `add` is not the optional nn.bias_add.
conv2d = relay.op.nn.conv2d(x, w)
conv_plus_w = conv2d + w
relu = relay.op.nn.relu(conv_plus_w)
assert not pattern.match(relu)

## 支配节点

In [28]:
# Pattern pieces: a conv2d "parent", unary element-wise ops on the paths,
# and an `add` where the paths rejoin.
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
elemwise_op = wildcard().has_attr({"TOpPattern": K_ELEMWISE})
is_unary_elemwise = elemwise_op(wildcard())
reduction = is_op("add")(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

# Classic diamond: conv2d feeds two branches that meet again at `add`.
inp = relay.var("input")
weight = relay.var("weight")
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(relay.op.nn.relu(conv2d))  # branch 1: relu -> relu
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)  # branch 2: leaky_relu
out = relu + leaky_relu

assert diamond.match(out)

## 重写

In [79]:
from tvm.ir import IRModule
from tvm.relay import Function

### 替换运算

In [32]:
x = relay.var("x")
y = relay.var("y")
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())


class TestRewrite(DFPatternCallback):
    """Rewrite callback that turns every call matched by `add_pattern` into a subtraction."""

    def __init__(self):
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        # Subtract the matched call's second operand from its first.
        lhs, rhs = post.args[0], post.args[1]
        return lhs - rhs


z = x + y
out = rewrite(TestRewrite(), z)
assert sub_pattern.match(out)

In [75]:
def rewrite_op(x, y, out, executor_kind="graph"):
    """Build an evaluator for the Relay function ``fn(x, y) { out }``.

    Despite its name, this helper does not rewrite anything. It wraps ``out``
    in a two-parameter ``Function``, places it in an ``IRModule`` and returns
    the callable produced by ``relay.create_executor(...).evaluate()``.

    Parameters
    ----------
    x, y : relay.Var
        Function parameters, in call order.
    out : relay.Expr
        Function body.
    executor_kind : str
        Executor kind passed to ``relay.create_executor``. The default is
        ``"graph"``; ``"vm"`` is an alternative.

    Returns
    -------
    Callable that takes the two inputs and returns the evaluated result.
    """
    func = Function([x, y], out)
    mod = IRModule.from_expr(func)
    return relay.create_executor(executor_kind, mod).evaluate()

x = relay.var("x", shape=[1], dtype="float32")
y = relay.var("y", shape=[1], dtype="float32")
z = x + y
out = rewrite(TestRewrite(), z)
evaluate = rewrite_op(x, y, z)
rewrite_evaluate = rewrite_op(x, y, out)
a = np.array([2], dtype="float32")
b = np.array([4], dtype="float32")
# The original graph adds; the rewritten one subtracts (2 + 4 vs 2 - 4).
np.testing.assert_allclose(evaluate(a, b).numpy(), a + b)
np.testing.assert_allclose(rewrite_evaluate(a, b).numpy(), a - b)
evaluate(a, b).numpy(), rewrite_evaluate(a, b).numpy()

### 重写函数

In [80]:
x = relay.var("x")
w = relay.var("w")
y = relay.var("y")
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())


class TestRewrite(DFPatternCallback):
    """Rewrite callback that turns a matched `a + b` into `a - b`."""

    def __init__(self):
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        first, second = post.args[0], post.args[1]
        return first - second


# fn(input, weight) { relu(conv2d(input, weight)) }, called on (x, w) and
# then added to y: the outer `+` is still rewritten into `-`.
inpf = relay.var("input")
weightf = relay.var("weight")
body = relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf))
func = relay.Function([inpf, weightf], body, attrs=None)
out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)

In [95]:
in_shape = (1, 4, 2, 2)
data = relay.var("data", shape=in_shape, dtype="float32")
stride = 2  # stride = 1 is meaningless for the reorg op, so use 2.
func = relay.op.vision.yolo.yolo_reorg(data, stride=stride)
mod = IRModule.from_expr(func)
executor = relay.create_executor("graph", mod)
evaluate = executor.evaluate()

# Build the input with the shape and dtype declared for `data`;
# np.arange alone would produce an integer array.
x = np.arange(np.prod(in_shape), dtype="float32").reshape(in_shape)
evaluate(x).shape

(1, 16, 1, 1)

## resize

In [96]:
from tvm.relay.dataflow_pattern import wildcard, is_op

# Pattern relu(bias_add(conv2d(_, _), _)), built from the inside out;
# every leaf operand is left unconstrained with a wildcard.
conv2d_p = is_op("nn.conv2d")(wildcard(), wildcard())
bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
relu_p = is_op("nn.relu")(bias_add_p)

<function tvm.relay.op.image.image.resize2d(data, size, roi=None, layout='NCHW', method='linear', coordinate_transformation_mode='half_pixel', rounding_method='', cubic_alpha=-0.5, cubic_exclude=0, extrapolation_value=0.0, out_dtype=None)>

In [98]:
in_shape = (1, 128, 20, 20)
data = relay.var("data", shape=in_shape, dtype="float32")
size = (40, 40)
func = relay.op.image.resize2d(data, size=size)
mod = IRModule.from_expr(func)
executor = relay.create_executor("graph", mod)
evaluate = executor.evaluate()

# Build the input from the same shape and the dtype declared for `data`;
# np.arange alone would produce an integer array.
x = np.arange(np.prod(in_shape), dtype="float32").reshape(in_shape)
y = evaluate(x)
y.shape

(1, 128, 40, 40)

In [102]:
# Display the most recently built IRModule as Relay text.
mod

#[version = "0.0.5"]
def @main(%data: Tensor[(1, 128, 20, 20), float32]) {
  image.resize2d(%data, size=[40, 40], roi=[0f, 0f, 0f, 0f], rounding_method="")
}