In [30]:
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm import relay
from tvm.ir.module import IRModule
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

In [2]:
from tvm import te
from tvm import topi

In [27]:
from tvm import relax as rx, tir

In [61]:
import torch
from torch import nn
from torch import fx
from torch.nn import functional as F

from d2l import torch as d2l



使用 BlockBuilder 构建函数时，注意两种发射（emit）方式：
- `bb.emit`：发射一个 relax 表达式（例如算子调用）
- `bb.emit_te`：根据 te 函数生成计算，并发射对它的调用

线性层Linear
- dense+add

In [34]:
def te_linear(X: te.Tensor, W: te.Tensor, B: te.Tensor, Z: te.Tensor) -> te.Tensor:
    """Linear layer followed by an element-wise add, expressed with TOPI.

    Computes ``X @ W^T + B + Z``.

    Parameters
    ----------
    X : te.Tensor
        Input activations; callers in this notebook pass ``(batch, in_dim)``.
    W : te.Tensor
        Weight in ``(out_dim, in_dim)`` layout, which is how every caller in
        this notebook declares it.
    B : te.Tensor
        1-D bias of length ``out_dim``.
    Z : te.Tensor
        Tensor added element-wise to the dense output.

    Returns
    -------
    te.Tensor
        The result of ``topi.add(dense(X, W, bias=B), Z)``.
    """
    # `topi.nn.dense` contracts the last axis of X with the last axis of W
    # (i.e. X @ W^T), which is what an (out_dim, in_dim) weight needs. A plain
    # `topi.nn.matmul` would interpret W as (in_dim, out_dim) instead.
    Y = topi.nn.dense(X, W, bias=B)
    return topi.add(Y, Z)

def te_linear_func():
    """Build a TIR PrimFunc named ``"linear"`` with fully symbolic shapes.

    The PrimFunc takes ``x (batch, indim)``, ``w (outdim, indim)``,
    ``b (outdim,)`` and ``z (batch, outdim)`` float32 buffers and writes
    ``x @ w^T + b + z`` into its last buffer.

    Returns
    -------
    The PrimFunc produced by ``te.create_prim_func`` with its
    ``global_symbol`` attribute set to ``"linear"``.
    """
    b = te.var("batch")  # batch
    indim = te.var("indim")  # indim
    outdim = te.var("outdim")  # outdim
    X = te.placeholder(shape=(b, indim), name="x", dtype="float32")
    W = te.placeholder(shape=(outdim, indim), name="w", dtype="float32")
    B = te.placeholder(shape=(outdim, ), name="b", dtype="float32")
    Z = te.placeholder(shape=(b, outdim), name="z", dtype="float32")
    # W is declared as (outdim, indim), so it must be contracted along its last
    # axis. `topi.nn.dense` does exactly that (X @ W^T); a plain matmul would
    # yield a (batch, indim) result that matches neither B nor Z.
    Y = topi.nn.dense(X, W, bias=B)
    res = topi.add(Y, Z)
    return te.create_prim_func([X, W, B, Z, res]).with_attr("global_symbol", "linear")

def linear(X: R.Tensor) -> R.Tensor:
    """Return whatever ``te_linear_func()`` builds.

    NOTE(review): ``X`` is never read and the returned object comes straight
    from ``te_linear_func`` rather than being derived from ``X`` -- this looks
    like an unfinished stub; confirm the intended behaviour.
    """
    return te_linear_func()

relu, 动态shape：

In [35]:
def relu_func():
    """Build a TIR PrimFunc named ``"relu"`` over a 2-D float32 tensor.

    Both extents of the input are symbolic TE variables, so the resulting
    PrimFunc is not tied to a fixed shape.
    """
    rows, cols = te.var("m"), te.var("n")
    data = te.placeholder(shape=(rows, cols), name="x", dtype="float32")
    prim_func = te.create_prim_func([data, topi.nn.relu(data)])
    return prim_func.with_attr("global_symbol", "relu")

FFN，基于位置的前馈神经网络（Position-wise FFN）

In [42]:
def PositiveWiseFFN():
    """Build a Relax module containing the function ``"positivewiseffn"``.

    The function chains ``te_linear -> relu -> te_linear`` inside a dataflow
    block. Its parameters, in order, are:

    - ``x``  : ``(b, m)`` input
    - ``w0`` : ``(n, m)``, ``b0`` : ``(n,)``, ``z0`` : ``(b, n)`` -- first linear
    - ``w1`` : ``(d, n)``, ``b1`` : ``(d,)``, ``z1`` : ``(b, d)`` -- second linear

    All extents are symbolic int64 TIR variables.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    b = tir.Var("b", "int64")
    # Keep each Python name identical to the TIR variable's printed name so the
    # shapes shown in the generated module match the code below.
    m, n, d = tir.Var("m", "int64"), tir.Var("n", "int64"), tir.Var("d", "int64")
    # Type annotations for tensors of dynamic shape (rank 2 and rank 1).
    type_anno = rx.DynTensorType(2, "float32")
    type_anno1 = rx.DynTensorType(1, "float32")
    x = rx.Var("x", [b, m], type_anno)
    w0 = rx.Var("w0", [n, m], type_anno)
    b0 = rx.Var("b0", [n, ], type_anno1)
    z0 = rx.Var("z0", [b, n], type_anno)

    w1 = rx.Var("w1", [d, n], type_anno)
    b1 = rx.Var("b1", [d, ], type_anno1)
    z1 = rx.Var("z1", [b, d], type_anno)
    fn_inputs = [x, w0, b0, z0, w1, b1, z1]
    with bb.function("positivewiseffn"):
        with bb.dataflow():
            lv0 = bb.emit_te(te_linear, x, w0, b0, z0)
            lv1 = bb.emit_te(topi.nn.relu, lv0)
            output = bb.emit_te(te_linear, lv1, w1, b1, z1)
            fn_output = bb.emit_output(output)
        # Parameters are supplied here because none were given to bb.function.
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [44]:
# Build the FFN module and display it via its `show()` method.
PositiveWiseFFN().show()

残差连接和层规范化

In [86]:
def te_layernorm(X: te.Tensor, gamma: te.Tensor, beta: te.Tensor):
    """Apply ``topi.nn.layer_norm`` to ``X`` along its last axis.

    ``gamma`` and ``beta`` are forwarded unchanged as the second and third
    arguments of ``topi.nn.layer_norm``.
    """
    last_axis = [-1]
    return topi.nn.layer_norm(X, gamma, beta, axis=last_axis)

In [90]:
# Fixed-shape float32 TE placeholders: a 128x128 tensor and two length-128 vectors.
# Presumably sample inputs for trying out `te_layernorm` -- confirm intended use.
A = te.placeholder(shape=(128, 128), name="A", dtype="float32")
gamma = te.placeholder(shape=(128, ), name="Gamma", dtype="float32")
beta = te.placeholder(shape=(128, ), name="Beta", dtype="float32")

In [None]:
def AddNorm():
    """Build a Relax module containing the function ``"AddNorm"``.

    The function applies dropout to ``x``, adds the result back onto ``x``
    (residual connection) and layer-normalises the sum over its last axis.
    Its parameters, in order, are:

    - ``x``     : ``(m, n, k)`` float32 input
    - ``gamma`` : ``(k,)`` float32 scale
    - ``beta``  : ``(k,)`` float32 shift

    All extents are symbolic int64 TIR variables.

    Returns
    -------
    The module returned by ``BlockBuilder.get()``.
    """
    bb = relax.BlockBuilder()
    m, n, k = tir.Var("m", "int64"), tir.Var("n", "int64"), tir.Var("k", "int64")
    type_anno = rx.DynTensorType(3, "float32")
    type_anno1 = rx.DynTensorType(1, "float32")

    # The function's parameters must be Relax variables declared here; the
    # builder cannot take TE placeholders or undefined names as inputs.
    X = rx.Var("x", [m, n, k], type_anno)
    gamma = rx.Var("gamma", [k, ], type_anno1)
    beta = rx.Var("beta", [k, ], type_anno1)
    fn_inputs = [X, gamma, beta]
    with bb.function("AddNorm"):
        with bb.dataflow():
            # NOTE(review): kept as written in the original cell. Confirm that
            # `relax.nn.dropout` exists in the installed TVM build and that it
            # yields a single tensor usable by `topi.add` below.
            lv0 = bb.emit(relax.nn.dropout(X))
            lv1 = bb.emit_te(topi.add, lv0, X)
            output = bb.emit_te(te_layernorm, lv1, gamma, beta)
            fn_output = bb.emit_output(output)
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

In [91]:
# Example of calling PyTorch's LayerNorm.
batch = 20
sentence_length = 5
embedding_dim = 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
# Normalise over the trailing embedding dimension only.
layer_norm = nn.LayerNorm(embedding_dim)
print(layer_norm.weight.shape, layer_norm.bias.shape)
print(layer_norm(embedding).shape)

torch.Size([10]) torch.Size([10])
torch.Size([20, 5, 10])


`bb.call_te`：根据te函数生成一个调用节点
- 该函数将来自relax表达式的参数转换为te tensor
- 回调函数应该返回一个te tensor或者te tensor的列表, 参考emit_te的例子

返回：
- ret: tvm.relax.Call, 返回新创建的Call node节点

`bb.emit_te`

In [None]:
# Set up a block builder and two rank-2 float32 Relax variables, `x` and `y`,
# that share the same symbolic shape (n, m).
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [n, m], type_anno)

def te_func(args, args_dict, msg):
    """TE callback computing the element-wise sum of two 2-D tensors.

    Parameters
    ----------
    args : sequence
        Its first element is used as the left operand ``A``.
    args_dict : dict
        Its ``"B"`` entry is used as the right operand ``B``.
    msg
        Accepted but not used by the computation.

    Returns
    -------
    te.Tensor
        A tensor with the same shape as ``A`` holding ``A[i, j] + B[i, j]``.
    """
    A = args[0]
    B = args_dict["B"]
    # Take the output extent from the input instead of hard-coding (128, 128),
    # so the callback also works for symbolic or differently sized inputs.
    return te.compute(A.shape, lambda i, j: A[i, j] + B[i, j])

# Build "rx_func": the function name goes first and the parameter list second,
# matching the name-first `bb.function("positivewiseffn")` call used earlier.
# `emit_te` receives a list, a dict and a keyword argument, which are handed to
# `te_func` as `args`, `args_dict` and `msg`.
with bb.function("rx_func", [x, y]):
    out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
    bb.emit_func_output(out)