In [1]:
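# NOTE: `!export` runs in a throwaway subshell, so it cannot change this kernel's
# environment; set PYTHONPATH before launching Jupyter (or extend `sys.path`) instead.
# !export PYTHONPATH="$TVM_HOME/python"
# !echo $PYTHONPATH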

# Define the model with the Relax frontend

In [2]:
import tvm
from tvm import relax
from tvm.relax.frontend import nn


class MLPModel(nn.Module):
    """Two-layer perceptron: 784 -> 256 (ReLU) -> 10, followed by softmax.

    The attribute names (`fc1`, `relu1`, `fc2`, `softmax`) are kept stable
    because the layers are registered on the module under these names.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer and its activation.
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        # Output layer; `nn.softmax` is stored as a plain callable.
        self.fc2 = nn.Linear(256, 10)
        self.softmax = nn.softmax

    def forward(self, x):
        """Run `x` through fc1 -> ReLU -> fc2 -> softmax and return the result."""
        hidden = self.relu1(self.fc1(x))
        logits = self.fc2(hidden)
        return self.softmax(logits)

In [3]:
# Describe the single entry point to export: `forward` takes one
# float32 tensor `x` of shape (1, 784).
export_spec = {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}

# Export the model; this yields the IRModule and the parameter spec
# that the later cells consume.
model = MLPModel()
mod, param_spec = model.export_tvm(spec=export_spec)
mod.show()

# Lower Relax to TensorIR
In the "zero" pipeline, 5 passes are applied:
- LegalizeOps() : "Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs."
- AnnotateTIROpPattern() : "Annotate Op Pattern Kind for TIR functions."
- FoldConstant() : "Fold constant expressions within dataflow blocks."
- FuseOps() : Groups the bindings in a dataflow block of each Relax function and generates a new grouped Relax function for each group, according to the fusion algorithm described in the pass implementation. The grouped bindings in the function being manipulated are then replaced by calls to the new grouped functions. A follow-up pass named "FuseTIR" generates a TIR PrimFunc for each grouped function.
- FuseTIR() : Fuses primitive Relax functions into larger TIR PrimFuncs where possible.

In [4]:
# Look up the "zero" pipeline, then apply it to the exported module.
zero_pipeline = relax.get_pipeline("zero")
pipe0 = zero_pipeline(mod)
# Bare last expression so the notebook displays the lowered module.
pipe0

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(relu: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matm

In [5]:
import numpy as np

# Fixed seed so the random input and parameters (and therefore the printed
# result) are identical on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

# Compile for the host CPU via LLVM and load the executable into a Relax VM.
# Note: this builds `mod` as exported above, not the lowered `pipe0`.
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)

# One random input matching the (1, 784) float32 spec used at export time.
data = rng.random((1, 784)).astype("float32")
tvm_data = tvm.nd.array(data, device=device)

# One random float32 array per entry of `param_spec`, in the same order,
# first as NumPy arrays and then copied onto the device.
params_np = [
    rng.random(tuple(spec_param.shape)).astype("float32")
    for _, spec_param in param_spec
]
params = [tvm.nd.array(array, device=device) for array in params_np]

# Values drawn uniformly from [0, 1) give large logits, so the softmax output
# tends to saturate to a near one-hot vector.
print(vm["forward"](tvm_data, *params).numpy())

[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
