In [2]:
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm import relay
from tvm.ir.module import IRModule
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

In [3]:
from tvm import te
from tvm import topi

In [4]:
from tvm import relax as rx, tir

In [5]:
import torch
from torch import nn
from torch import fx
from torch.nn import functional as F

from d2l import torch as d2l

In [6]:
import math

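编写tir函数时，注意 BlockBuilder 的两个接口：
- bb.emit：发射一个 relax 表达式（例如 relax.op.* 的调用），并返回绑定该结果的变量
- bb.emit_te：把 TE/TOPI 计算转换为 PrimFunc，发射对它的调用，并返回绑定该结果的变量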

In [None]:
# Parameter tables for the layers built below; values are filled in later.
ffn_params = {
    "w0": None, "b0": None,
    "w1": None, "b1": None
}
# `{""}` would create a *set* holding an empty string; use a dict like
# ffn_params, keyed by the parameter names that addnorm() declares.
addnorm_params = {
    "gamma": None, "beta": None
}

线性层Linear
- dense+add

In [26]:
def te_linear(X: te.Tensor, W: te.Tensor, B: te.Tensor, Z: te.Tensor) -> te.Tensor:
    """Dense layer followed by an element-wise add: ``X @ W^T + B + Z``.

    Parameters
    ----------
    X : te.Tensor
        Input; callers in this notebook pass shape ``(batch, in_dim)``.
    W : te.Tensor
        Weight in dense layout; callers pass shape ``(out_dim, in_dim)``.
    B : te.Tensor
        1-D bias forwarded to ``topi.nn.matmul``.
    Z : te.Tensor
        Tensor added to the matmul result with ``topi.add``.

    Returns
    -------
    te.Tensor
        The result of ``topi.add(matmul(X, W^T) + B, Z)``.
    """
    # W is laid out as (out_dim, in_dim), so the reduction has to run over
    # W's last axis; without transpose_b=True the contraction axes of X and W
    # do not line up.
    Y = topi.nn.matmul(X, W, bias=B, transpose_b=True)
    return topi.add(Y, Z)

def te_linear_func():
    """Build a TIR PrimFunc named "linear" computing ``x @ w^T + b + z``.

    All extents are symbolic (``batch``, ``indim``, ``outdim``). The PrimFunc
    parameters are, in order: x ``(batch, indim)``, w ``(outdim, indim)``,
    b ``(outdim,)``, z ``(batch, outdim)`` and the result tensor.
    """
    batch = te.var("batch")
    in_dim = te.var("indim")
    out_dim = te.var("outdim")
    X = te.placeholder(shape=(batch, in_dim), name="x", dtype="float32")
    W = te.placeholder(shape=(out_dim, in_dim), name="w", dtype="float32")
    B = te.placeholder(shape=(out_dim, ), name="b", dtype="float32")
    Z = te.placeholder(shape=(batch, out_dim), name="z", dtype="float32")
    # w is declared as (outdim, indim), so it must be consumed transposed;
    # the default transpose_b=False would reduce x's indim against w's outdim.
    Y = topi.nn.matmul(X, W, bias=B, transpose_b=True)
    res = topi.add(Y, Z)
    return te.create_prim_func([X, W, B, Z, res]).with_attr("global_symbol", "linear")

def linear(X: R.Tensor) -> R.Tensor:
    """Return whatever te_linear_func() builds.

    NOTE(review): ``X`` is accepted but never used, and the value returned is
    the object produced by te_linear_func() rather than a tensor expression;
    confirm whether this wrapper is still a work in progress.
    """
    return te_linear_func()

relu, 动态shape：

In [27]:
def relu_func():
    """Build a TIR PrimFunc named "relu" over a 2-D float32 tensor.

    Both extents of the input are symbolic (``m`` and ``n``); the body is
    ``topi.nn.relu`` applied to the input placeholder.
    """
    rows, cols = te.var("m"), te.var("n")
    data = te.placeholder(shape=(rows, cols), name="x", dtype="float32")
    prim_func = te.create_prim_func([data, topi.nn.relu(data)])
    return prim_func.with_attr("global_symbol", "relu")

FFN, 前馈神经网络

In [32]:
def PositiveWiseFFN():
    """Build an IRModule with a relax function "positivewiseffn".

    The function chains ``te_linear -> topi.nn.relu -> te_linear`` on
    symbolic-shape inputs. Parameters, in order:

    - x  ``(b, m)``
    - w0 ``(n, m)``, b0 ``(n,)``, z0 ``(b, n)``  (first te_linear call)
    - w1 ``(d, n)``, b1 ``(d,)``, z1 ``(b, d)``  (second te_linear call)

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    b = tir.Var("b", "int64")
    # Each symbolic var carries the same name as the Python variable holding
    # it, so the printed IR matches the shapes written below (previously the
    # vars bound to `m` and `n` were named "n" and "m").
    m, n, d = tir.Var("m", "int64"), tir.Var("n", "int64"), tir.Var("d", "int64")
    # dynamic-shape type annotations: rank-2 and rank-1 float32 tensors
    type_anno = rx.DynTensorType(2, "float32")
    type_anno1 = rx.DynTensorType(1, "float32")
    x = rx.Var("x", [b, m], type_anno)
    w0 = rx.Var("w0", [n, m], type_anno)
    b0 = rx.Var("b0", [n, ], type_anno1)
    z0 = rx.Var("z0", [b, n], type_anno)

    w1 = rx.Var("w1", [d, n], type_anno)
    b1 = rx.Var("b1", [d, ], type_anno1)
    z1 = rx.Var("z1", [b, d], type_anno)
    fn_inputs = [x, w0, b0, z0, w1, b1, z1]
    with bb.function("positivewiseffn"):
        with bb.dataflow():
            lv0 = bb.emit_te(te_linear, x, w0, b0, z0)
            lv1 = bb.emit_te(topi.nn.relu, lv0)
            output = bb.emit_te(te_linear, lv1, w1, b1, z1)
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [33]:
# Build the module with PositiveWiseFFN() and display it via .show().
PositiveWiseFFN().show()

残差连接和层规范化

In [33]:
def te_layernorm(X: te.Tensor, gamma: te.Tensor, beta: te.Tensor):
    """Apply ``topi.nn.layer_norm`` to ``X`` over its last axis.

    ``gamma`` and ``beta`` are forwarded unchanged to ``topi.nn.layer_norm``.
    """
    return topi.nn.layer_norm(X, gamma, beta, axis=[-1])

In [90]:
# Static-shape float32 TE placeholders (presumably for experimenting with
# te_layernorm by hand - they are not used by the cells shown here).
A = te.placeholder((128, 128), dtype="float32", name="A")
gamma = te.placeholder((128, ), dtype="float32", name="Gamma")
beta = te.placeholder((128, ), dtype="float32", name="Beta")

In [34]:
def addnorm():
    """Build an IRModule with a relax function "addnorm".

    The function takes x ``(b, m, n)``, gamma ``(n,)`` and beta ``(n,)``
    (all float32, symbolic extents) and emits
    ``te_layernorm(dropout(x) + x, gamma, beta)``.

    NOTE(review): dropout is applied to ``x`` and the result is added back to
    ``x`` itself; there is no separate sub-layer output argument - confirm
    this is intended.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    b = tir.Var("b", "int64")
    m, n = tir.Var("m", "int64"), tir.Var("n", "int64")

    type_anno = rx.DynTensorType(3, "float32")
    type_anno1 = rx.DynTensorType(1, "float32")

    x = rx.Var("x", [b, m, n], type_anno)
    gamma = rx.Var("gamma", [n], type_anno1)
    beta = rx.Var("beta", [n], type_anno1)
    fn_inputs = [x, gamma, beta]
    with bb.function("addnorm"):
        with bb.dataflow():
            # dropout yields a (output, mask) tuple; element 0 is the tensor
            # to keep using - element 1 is only the mask.
            lv0 = bb.emit(R.TupleGetItem(relax.nn.dropout(x), 0))
            lv1 = bb.emit_te(topi.add, lv0, x)
            output = bb.emit_te(te_layernorm, lv1, gamma, beta)
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [35]:
# Build the module with addnorm() and display it via .show().
addnorm().show()

transpose

In [87]:
def transpose_qkv(num_heads=2):
    """Build an IRModule with a relax function "transpose_qkv".

    The function takes a float32 tensor x ``(b, m, n)`` with symbolic extents
    and emits, in order:

    1. reshape to ``(b, m, num_heads, -1)``
    2. transpose with axes ``(0, 2, 1, 3)``
    3. reshape to ``(-1, <dim 2 of step 2>, <dim 3 of step 2>)``

    Parameters
    ----------
    num_heads : int
        Size of the head axis introduced by the first reshape.
    """
    builder = relax.BlockBuilder()
    batch = T.Var("b", "int64")
    seq_len, hidden = T.Var("m", "int64"), T.Var("n", "int64")

    x = rx.Var("x", [batch, seq_len, hidden], rx.DynTensorType(3, "float32"))

    with builder.function("transpose_qkv"):
        with builder.dataflow():
            split_heads = builder.emit(relax.op.reshape(x, (batch, seq_len, num_heads, -1)))
            heads_first = builder.emit(relax.op.transpose(split_heads, (0, 2, 1, 3)))
            merged_shape = (-1, heads_first.shape[2], heads_first.shape[3])
            merged = builder.emit(relax.op.reshape(heads_first, merged_shape))
            fn_output = builder.emit_output(merged)
        builder.emit_func_output(fn_output, [x])
    return builder.get()

def transpose_output(num_heads=2):
    """Build an IRModule with a relax function "transpose_output".

    The function takes a float32 tensor x ``(bh, m, ndh)`` with symbolic
    extents and emits, in order:

    1. reshape to ``(-1, num_heads, m, ndh)``
    2. transpose with axes ``(0, 2, 1, 3)``
    3. reshape to ``(<dim 0 of step 2>, <dim 1 of step 2>, -1)``

    Parameters
    ----------
    num_heads : int
        Size of the head axis split out by the first reshape.
    """
    builder = relax.BlockBuilder()
    batch_heads = T.Var("bh", "int64")
    seq_len, head_dim = T.Var("m", "int64"), T.Var("ndh", "int64")

    x = rx.Var("x", [batch_heads, seq_len, head_dim], rx.DynTensorType(3, "float32"))

    with builder.function("transpose_output"):
        with builder.dataflow():
            split_heads = builder.emit(relax.op.reshape(x, (-1, num_heads, seq_len, head_dim)))
            seq_first = builder.emit(relax.op.transpose(split_heads, (0, 2, 1, 3)))
            merged_shape = (seq_first.shape[0], seq_first.shape[1], -1)
            merged = builder.emit(relax.op.reshape(seq_first, merged_shape))
            fn_output = builder.emit_output(merged)
        builder.emit_func_output(fn_output, [x])
    return builder.get()

In [88]:
# Build the transpose_qkv module for two heads and display it via .show().
transpose_qkv(num_heads=2).show()

In [89]:
# Build the transpose_output module for two heads and display it via .show().
transpose_output(num_heads=2).show()

masked_softmax

In [44]:
def masked_softmax_dyn():
    """Build an IRModule with a symbolic-shape relax function "masked_softmax".

    The function takes x ``(b, m, n)`` and valid_lens ``(n,)`` (float32) and
    emits: reshape x to ``(-1, n)`` -> ``topi.sequence_mask`` with mask value
    ``-1e6`` on axis 1 -> softmax over the last axis.

    NOTE(review): the result is not reshaped back to ``(b, m, n)``, unlike the
    static-shape ``masked_softmax`` in this notebook - confirm intent.
    NOTE(review): ``topi.sequence_mask`` presumably expects one valid length
    per row of the flattened input, which would be ``b * m`` entries rather
    than ``n`` - verify against the TOPI documentation.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    b = T.Var("b", "int64")
    m, n = T.Var("m", "int64"), T.Var("n", "int64")

    type_anno = rx.DynTensorType(3, "float32")
    # valid_lens is declared with a 1-D shape, so its type annotation has to
    # be rank 1 as well (it was rank 2 before).
    type_anno1 = rx.DynTensorType(1, "float32")
    x = rx.Var("x", [b, m, n], type_anno)
    valid_lens = rx.Var("valid_lens", [n], type_anno1)
    fn_inputs = [x, valid_lens]
    with bb.function("masked_softmax"):
        with bb.dataflow():
            lv0 = bb.emit(relax.op.reshape(x, (-1, n)))
            lv1 = bb.emit_te(topi.sequence_mask, lv0, valid_lens, -1e6, 1)
            output = bb.emit(relax.nn.softmax(lv1, axis=-1))
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

def masked_softmax(x_shape: R.Shape, valid_shape: R.Shape):
    """Build an IRModule with a static-shape relax function "masked_softmax".

    The function takes x (float32, shape ``x_shape``) and valid_lens (int64,
    shape ``valid_shape``) and emits: reshape x to ``(-1, x_shape[2])`` ->
    ``topi.sequence_mask`` with mask value ``-1e6`` on axis 1 -> reshape back
    to ``x_shape`` -> softmax over the last axis.

    Parameters
    ----------
    x_shape : sequence of int
        Shape of x; index 2 is read, so at least three entries are required.
    valid_shape : sequence of int
        Shape of valid_lens.

    NOTE(review): ``topi.sequence_mask`` presumably expects one valid length
    per row of the flattened input - verify ``valid_shape`` against that.
    """
    builder = relax.BlockBuilder()

    x = rx.Var("x", x_shape, rx.DynTensorType(len(x_shape), "float32"))
    valid_lens = rx.Var("valid_lens", valid_shape, rx.DynTensorType(len(valid_shape), "int64"))

    with builder.function("masked_softmax"):
        with builder.dataflow():
            flat = builder.emit(relax.op.reshape(x, (-1, x_shape[2])))
            masked = builder.emit_te(topi.sequence_mask, flat, valid_lens, -1e6, 1)
            restored = R.reshape(masked, x_shape)
            probs = builder.emit(relax.nn.softmax(restored, axis=-1))
            fn_output = builder.emit_output(probs)
        builder.emit_func_output(fn_output, [x, valid_lens])
    return builder.get()

def rx_masked_softmax(x: rx.Expr, valid_lens: rx.Expr)-> rx.Expr:
    """Return ``R.nn.softmax(x, axis=-1)``.

    ``valid_lens`` is accepted but currently ignored: no masking is applied.
    The disabled draft of a TE-based masking path is kept below for reference.
    """
    # Disabled draft:
    # print(x.shape)
    # x_reshape = R.reshape(x, (-1, x.shape[2]))
    # te_x, te_valid = rx.te_tensor(x_reshape, "te_x"), rx.te_tensor(valid_lens, "te_valid")
    # R.match_shape(te_x, x_reshape.shape)
    # R.match_shape(te_valid, valid_lens.shape)
    # seq_m = topi.sequence_mask(te_x, valid_length=te_valid, mask_value=-1e6, axis=1)
    # return R.nn.softmax(seq_m, axis=-1)
    softmax_out = R.nn.softmax(x, axis=-1)
    return softmax_out

In [45]:
# Build the static-shape masked_softmax module and display it via .show().
masked_softmax(x_shape=(10, 20, 30), valid_shape=(30, )).show()

In [17]:
# masked_softmax() requires x_shape and valid_shape; calling it without
# arguments raises TypeError on a fresh kernel, so pass explicit shapes.
rx.VarBinding(
    var=rx.Var("masked_softmax"),
    value=masked_softmax(x_shape=(10, 20, 30), valid_shape=(30, ))["masked_softmax"],
)

relax.expr.VarBinding(0x646f0b0)

In [67]:
# Static-shape relax Vars used to exercise rx_masked_softmax.
A = rx.Var("A", shape_annotation=(10, 20, 30), type_annotation=relax.DynTensorType(3, "float32"))
# B has a 2-D shape, so its type annotation must be rank 2 (it was rank 1).
B = rx.Var("B", shape_annotation=(10, 10), type_annotation=relax.DynTensorType(2, "float32"))

In [68]:
# Call rx_masked_softmax on the relax Vars A and B; the cell displays the returned expression.
rx_masked_softmax(A, B)

(10, 20, 30) (10, 10)


In [30]:
# masked_softmax() requires x_shape and valid_shape; pass explicit shapes so
# this cell runs on a fresh kernel.
masked_softmax(x_shape=(10, 20, 30), valid_shape=(30, )).show()

In [12]:
# Look up the relax function inside the built module. masked_softmax()
# requires x_shape and valid_shape, so pass explicit shapes.
masked_softmax(x_shape=(10, 20, 30), valid_shape=(30, ))["masked_softmax"]

relax.expr.Function(0x7bb53d0)

In [21]:
# Fresh BlockBuilder; the next cell registers functions on it with add_func.
bb = relax.BlockBuilder()

In [22]:
# Build the module once (masked_softmax() requires explicit shapes and would
# raise TypeError without them), then register both of its functions.
masked_softmax_mod = masked_softmax(x_shape=(10, 20, 30), valid_shape=(30, ))
bb.add_func(masked_softmax_mod["masked_softmax"], func_name="masked_softmax")
bb.add_func(masked_softmax_mod["sequence_mask"], func_name="sequence_mask")

GlobalVar(sequence_mask)

In [23]:
# Display the module assembled on `bb` so far.
bb.get().show()

bmm

In [80]:
# Static-shape batch matmul demo: (10, 40, 30) x (10, 30, 20), with B
# consumed as given (transpose_b=False); display the resulting PrimFunc.
A = te.placeholder((10, 40, 30), dtype="float32", name="A")
B = te.placeholder((10, 30, 20), dtype="float32", name="B")
C = topi.nn.batch_matmul(A, B, transpose_b=False)
te.create_prim_func([A, B, C]).show()

In [132]:
def te_bmm(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    """Batch-multiply two rank-3 tensors without transposing ``B``.

    Asserts that both operands are rank 3 and that the last extent of ``A``
    equals the second-to-last extent of ``B``, then returns
    ``topi.nn.batch_matmul(A, B, transpose_b=False)``.
    """
    for operand in (A, B):
        assert len(operand.shape) == 3
    assert A.shape[-1] == B.shape[-2]
    return topi.nn.batch_matmul(A, B, transpose_b=False)

def bmm_t(A_shape: R.Shape, B_shape: R.Shape):
    """Build an IRModule with a relax function "bmm_t" wrapping ``te_bmm``.

    Parameters
    ----------
    A_shape, B_shape : sequence of int
        Static shapes of the two float32 inputs x and y.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    # Annotate the inputs with a concrete tensor type, as every other builder
    # in this notebook does; `R.Tensor` itself is not a type instance.
    x = rx.Var("x", A_shape, rx.DynTensorType(len(A_shape), "float32"))
    y = rx.Var("y", B_shape, rx.DynTensorType(len(B_shape), "float32"))
    fn_inputs = [x, y]
    with bb.function("bmm_t"):
        with bb.dataflow():
            lv0 = bb.emit_te(te_bmm, x, y)
            fn_output = bb.emit_output(lv0)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [134]:
# bmm_t((10, 20, 30), (10, 30, 40)).show()

点积注意力机制

In [46]:
def dotproduct_attention_dynamic():
    """Draft of a symbolic-shape relax function "dotproduct_attention".

    Declares queries ``(b, q, d)``, keys ``(b, k, d)``, values ``(b, k, d)``
    (float32) and valid_lens ``(b,)`` (int64), then tries to emit
    batch_matmul -> scale -> masked softmax -> dropout -> batch_matmul.

    NOTE(review): this looks like an unfinished draft; see the inline notes.
    The static-shape ``dotproduct_attention`` in this notebook shows the
    variant of each step that the author later settled on.
    """
    bb = relax.BlockBuilder()
    b = T.Var("b", "int64")
    q, d = T.Var("q", "int64"), T.Var("d", "int64")
    k = T.Var("k", "int64")

    type_anno = rx.DynTensorType(3, "float32")
    type_anno2 = rx.DynTensorType(3, "float32")
    typa_anno3 = rx.DynTensorType(1, "int64")

    queries = rx.Var("queries", [b, q, d], type_anno)
    keys = rx.Var("keys", [b, k, d], type_anno2)
    values = rx.Var("values", [b, k, d], type_anno2)
    valid_lens = rx.Var("valid_lens", [b], typa_anno3)
    
    fn_inputs = [queries, keys, values, valid_lens]
    fn_output = None
    with bb.function("dotproduct_attention"):
        with bb.dataflow():
            k_transpose = bb.emit(relax.op.transpose(keys, (0,2,1)))
            # NOTE(review): keys are transposed explicitly above, yet batch_matmul is
            # called without transpose_b=False (the static version passes un-transposed
            # keys here) - confirm this does not transpose twice.
            lv0 = bb.emit_te(topi.nn.batch_matmul, rx.te_tensor(queries), rx.te_tensor(k_transpose))
            # NOTE(review): `d` is a symbolic variable here, so math.sqrt(d) presumably
            # fails; the result is also not passed through bb.emit - confirm.
            scores = R.divide(lv0, math.sqrt(d))
            # NOTE(review): masked_softmax in this notebook takes (x_shape, valid_shape)
            # and returns a module; calling it with no arguments looks stale - confirm.
            attention_weights = bb.emit(masked_softmax()(scores, valid_lens))
            # NOTE(review): the static version unpacks dropout with R.TupleGetItem and
            # passes transpose_b=False for the values matmul; neither happens here.
            output = bb.emit_te(topi.nn.batch_matmul, relax.nn.dropout(attention_weights), values)
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

def dotproduct_attention(queries_shape: R.Shape, keys_shape: R.Shape, 
            values_shape: R.Shape, valid_lens_shape: R.Shape):
    """Build an IRModule with a static-shape relax function "dotproduct_attention".

    The emitted function computes, in order:
    ``batch_matmul(queries, keys)`` -> divide by ``sqrt(d)`` -> masked softmax
    -> dropout -> ``batch_matmul(weights, values, transpose_b=False)``.
    The module also contains the "masked_softmax" and "sequence_mask"
    functions produced by ``masked_softmax``.

    Parameters
    ----------
    queries_shape : sequence of int
        ``(batch, num_queries, d)``.
    keys_shape : sequence of int
        ``(batch, num_keys, d)``.
    values_shape : sequence of int
        ``(batch, num_keys, value_dim)``.
    valid_lens_shape : sequence of int
        Shape of the int64 valid_lens input, forwarded to ``masked_softmax``.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()

    # The attention scores are (batch, num_queries, num_keys); derive that
    # from the arguments instead of hard-coding (10, 20, 40).
    scores_shape = (queries_shape[0], queries_shape[1], keys_shape[1])
    masked_softmax_mod = masked_softmax(x_shape=scores_shape, valid_shape=valid_lens_shape)
    bb.add_func(masked_softmax_mod["sequence_mask"], "sequence_mask")
    # Named *_gv so it does not shadow the module-level rx_masked_softmax().
    masked_softmax_gv = bb.add_func(masked_softmax_mod["masked_softmax"], "masked_softmax")

    # relax argument declarations
    queries = rx.Var("queries", queries_shape, rx.DynTensorType(len(queries_shape), "float32"))
    keys = rx.Var("keys", keys_shape, rx.DynTensorType(len(keys_shape), "float32"))
    values = rx.Var("values", values_shape, rx.DynTensorType(len(values_shape), "float32"))
    valid_lens = rx.Var("valid_lens", valid_lens_shape, rx.DynTensorType(len(valid_lens_shape), "int64"))

    d = queries.shape[2]
    fn_inputs = [queries, keys, values, valid_lens]

    with bb.function("dotproduct_attention"):
        with bb.dataflow():
            lv0 = bb.emit_te(topi.nn.batch_matmul, queries, keys)
            scores = bb.emit(R.divide(lv0, rx.const(math.sqrt(int(d)))))
            attention_weights = bb.emit(masked_softmax_gv(scores, valid_lens))
            # dropout yields a (output, mask) tuple; element 0 is the weights
            # tensor to multiply with values - element 1 is only the mask.
            w_dp = bb.emit(R.TupleGetItem(R.nn.dropout(attention_weights), 0))
            output = bb.emit_te(topi.nn.batch_matmul, w_dp, values, transpose_b=False)
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [47]:
# Build the static-shape attention module for these example shapes and display it via .show().
dotproduct_attention(queries_shape=(10, 20, 30), 
            keys_shape=(10, 40, 30),
            values_shape=(10, 40, 100), 
            valid_lens_shape=(40, )).show()

(10, 20, 40) (40,)
(10, 20, 40)
(10, 20, 40) (10, 40, 100)


`bb.call_te`：根据te函数生成一个调用节点
- 该函数将来自relax表达式的参数转换为tensor
- 回调函数应该返回一个te tensor或者te tensor的列表, 参考emit_te的例子

返回：
- ret: tvm.relax.Call，新创建的 Call 节点

`bb.emit_te`

In [None]:
# Setup for the emit_te example: a builder plus two rank-2 float32 relax
# Vars that share the symbolic shape (n, m).
bb = rx.BlockBuilder()
n = tir.Var("n", "int64")
m = tir.Var("m", "int64")
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [n, m], type_anno)

def te_func(args, args_dict, msg):
    """Element-wise sum of ``args[0]`` and ``args_dict["B"]`` over a (128, 128) domain.

    ``msg`` is accepted but not used; it only demonstrates that emit_te
    forwards extra keyword arguments to the TE function.
    """
    lhs, rhs = args[0], args_dict["B"]
    return te.compute((128, 128), lambda i, j: lhs[i, j] + rhs[i, j])

# Follow the BlockBuilder convention used by every other cell in this
# notebook: the function name goes to bb.function() and the parameter list
# goes to emit_func_output(). The previous call passed the parameters where
# the name is expected.
with bb.function("rx_func"):
    out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
    bb.emit_func_output(out, [x, y])