In [None]:
import tvm

import numpy as np

import tvm.te as te
import tvm.relax as rx
import tvm.tir as tir

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: tvm.ir.module.IRModule):
    """Print ``mod`` through ``mod.show`` with this notebook's fixed settings.

    Args:
        mod: The module to display. Only its ``show`` method is used.

    Returns:
        None. The result of ``mod.show`` is discarded.
    """
    # Collected in one place so the display settings are easy to tweak.
    display_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**display_options)


def createandshowmod(ops):
    """Turn TE operations into a PrimFunc, wrap it in an IRModule and print it.

    Args:
        ops: Forwarded unchanged to ``te.create_prim_func``.

    Returns:
        None. The module is only displayed via ``showmod``.
    """
    prim_func = te.create_prim_func(ops)
    # Same name is used for the global symbol attribute and the module key.
    prim_func = prim_func.with_attrs({"global_symbol": "test"})
    showmod(tvm.IRModule({"test": prim_func}))


import tvm.testing
import tvm.relax.testing.vm

In [None]:
def test_vm_build():
    """Build two small Relax modules for LLVM/CPU and run them on the Relax VM.

    Part 1 builds ``foo`` with ``exec_mode="bytecode"``. ``foo`` calls the
    packed function ``test.vm.identity`` on its two inputs and returns ``y``.
    Part 2 builds a 64x64 ``R.matmul`` with ``exec_mode="compiled"`` and checks
    the VM result against ``np.matmul``.

    Returns:
        The executable produced by ``tvm.relax.build`` for the matmul module.
    """
    @tvm.script.ir_module
    class test_vm_build_mod:
        @R.function
        def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
            # NOTE `test.vm.identity` is registered in `tvm/relax/testing/vm.py`
            # The call's result is bound to `_` and never used; `foo` returns `y`.
            _ = R.call_pure_packed(
                "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
            )
            return y

    mod = test_vm_build_mod
    target = tvm.target.Target("llvm", host="llvm")
    ex = tvm.relax.build(mod, target, exec_mode="bytecode")
    vm = tvm.relax.VirtualMachine(ex, tvm.cpu())

    # Despite the `np` prefix, these two are tvm.nd.array objects, not numpy arrays.
    np1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
    np2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))

    y = vm["foo"](np1, np2)
    print(f"y.numpy(): \n{y.numpy()}\n")
    # This compares the two *inputs* after the call. It can only pass if
    # `test.vm.identity` wrote x's data into y in place -- presumably that is
    # what the packed function does; confirm in `tvm/relax/testing/vm.py`.
    tvm.testing.assert_allclose(np2.numpy(), np1.numpy(), rtol=1e-7, atol=1e-7)

    # matmul mod
    @tvm.script.ir_module
    class matmul_mod:
        @R.function
        def matmul(x: R.Tensor((64, 64), "float32"), y: R.Tensor((64, 64), "float32")):
            z = R.matmul(x, y)
            return z

    mod2 = matmul_mod
    target = tvm.target.Target("llvm", host="llvm")
    # `ex` and `vm` are rebound here; the bytecode executable from part 1 is dropped.
    ex = tvm.relax.build(mod2, target, exec_mode="compiled")
    # we can also use `tvm.compile` to build the module
    # ex = tvm.compile(mod2, target=target)

    # BUG @benkangpeng The content printed below is meaningless.
    # ex: VMExecutable
    # print(ex.stats())
    # print(ex.as_python())
    # print(ex.as_text())

    vm = tvm.relax.VirtualMachine(ex, tvm.cpu())

    # In this half np1/np2 are real numpy arrays, converted to tvm arrays at the call.
    np1 = np.random.rand(64, 64).astype(np.float32)
    np2 = np.random.rand(64, 64).astype(np.float32)

    # numpy reference result for the VM output.
    np3 = np.matmul(np1, np2)
    res = vm["matmul"](tvm.nd.array(np1), tvm.nd.array(np2))
    tvm.testing.assert_allclose(res.numpy(), np3, rtol=1e-5, atol=1e-5)
    print(f"res.numpy(): \n{res.numpy()}\n")

    return ex


# Run the demo; `ex` is the executable built from the matmul module, and the
# following cells read generated source from `ex.mod`.
ex = test_vm_build()

y.numpy(): 
[[0.10472399 0.66643685 0.35652795 0.4084878 ]
 [0.41518262 0.6130509  0.6014363  0.10648939]
 [0.816477   0.5540616  0.7546119  0.37130144]]

res.numpy(): 
[[16.189398 17.059381 17.948286 ... 17.370955 17.224758 16.96119 ]
 [15.189768 16.005981 17.682283 ... 16.872387 17.490295 17.65356 ]
 [16.778374 14.794839 17.757568 ... 17.0416   17.32591  15.865465]
 ...
 [14.329997 14.277385 16.929363 ... 16.507309 16.752064 16.218044]
 [18.177542 16.748013 19.14382  ... 18.983114 19.113914 18.340569]
 [15.993771 14.182335 16.262959 ... 15.553992 16.757072 17.118176]]



In [None]:
# Get LLVM IR code
# `ex` comes from test_vm_build(); the source text is read from the first
# module imported by `ex.mod`.
code = ex.mod.imported_modules[0].get_source("ll")  # Or get_source("")
print(code[:1000])  # Print the first 1000 characters

; ModuleID = 'TVMMod'
source_filename = "TVMMod"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%0 = type { double }

@__tvm_module_ctx = linkonce dllexport local_unnamed_addr global ptr null, align 8
@__TVMFuncCall = linkonce dllexport local_unnamed_addr global ptr null, align 8
@__TVMBackendGetFuncFromEnv = linkonce dllexport local_unnamed_addr global ptr null, align 8
@__TVMAPISetLastError = linkonce dllexport local_unnamed_addr global ptr null, align 8
@.str = private constant [66 x i8] c"Assert fail: num_args == 4, __vmtir__matmul: num_args should be 4\00", align 1
@.str.1 = private constant [84 x i8] c"Assert fail: not T.isnullptr(args), __vmtir__matmul: TVMValue* arg pointer was NULL\00", align 1
@.str.2 = private constant [86 x i8] c"Assert fail: not T.isnullptr(arg_type_ids), __vmtir__matmul: int* type_codes was NULL\00", align 1
@.str.3 = private constant [141 x i8] c"Assert fai

In [None]:
# Get Assembly code
# Same imported module as the LLVM IR cell; note that `code` is rebound here.
code = ex.mod.imported_modules[0].get_source("asm")  # Or get_source("s")
print(code[:400])  # Print the first 400 characters

	.text
	.file	"TVMMod"
	.globl	__vmtir__matmul
	.p2align	4
	.type	__vmtir__matmul,@function
__vmtir__matmul:
.Lfunc_begin0:
	.file	1 "." "IRModule.CodeGenLLVM"
	.loc	1 0 0
	.cfi_startproc
	subq	$120, %rsp
	.cfi_def_cfa_offset 128
.Ltmp0:
	cmpl	$4, %edx
	jne	.LBB0_1
.Ltmp1:
	testq	%rdi, %rdi
	je	.LBB0_4
.Ltmp2:
	testq	%rsi, %rsi
	je	.LBB0_6
.Ltmp3:
	movl	(%rsi), %eax
.Ltmp4:
	cmpl	$13, %eax
	ja	.LB


In [None]:
def test_vmcodegen():
    """Run ``tvm.relax.vm_build._vmcodegen`` on a TIR-only module and print the result.

    The input module contains a single PrimFunc ``matmul`` with fixed shapes
    (16x32 times 32x64 into 16x64). ``_vmcodegen`` is called with a fresh
    ``ExecBuilder`` and ``exec_mode="compiled"``; the module it returns is
    displayed with ``showmod``. Nothing is returned.
    """
    @tvm.script.ir_module
    class test_vmcodegen_mod:
        @T.prim_func
        def matmul(
            x: T.Buffer((16, 32), "float32"),
            y: T.Buffer((32, 64), "float32"),
            z: T.Buffer((16, 64), "float32"),
        ):
            T.func_attr({"global_symbol": "matmul"})
            # Loop order is (i, j, k) over (16, 64, 32): rows of z, columns of z,
            # then the shared dimension of x and y.
            for i, j, k in T.grid(16, 64, 32):
                with T.block("T_matmul"):
                    # "SSR": i and j are spatial axes, k is the reduction axis.
                    i_1, j_1, k_1 = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        z[i_1, j_1] = T.float32(0)
                    # Accumulate z[i, j] += x[i, k] * y[k, j].
                    z[i_1, j_1] = z[i_1, j_1] + x[i_1, k_1] * y[k_1, j_1]

    builder = tvm.relax.ExecBuilder()
    # `_vmcodegen` is a private (underscore-prefixed) helper of tvm.relax.vm_build.
    mod = tvm.relax.vm_build._vmcodegen(
        builder, test_vmcodegen_mod, exec_mode="compiled"
    )
    showmod(mod)


# Run the demo; it prints the module returned by `_vmcodegen` and returns nothing.
test_vmcodegen()

#### _vmcodegen

In [None]:
# Fresh ExecBuilder that is handed to `_vmcodegen` at the end of this cell.
builder = tvm.relax.ExecBuilder()


# Minimal Relax module: despite its name, `add` returns its (3, 4) float32
# input unchanged.
@tvm.script.ir_module
class Module:
    @R.function
    def add(x: R.Tensor((3, 4), "float32")):
        return x


mod = Module
# Show the module before code generation ...
showmod(mod)

# ... and the module returned by the private `_vmcodegen` helper, for comparison.
# `mod` is rebound to that result.
mod = tvm.relax.vm_build._vmcodegen(builder, mod, exec_mode="compiled")
showmod(mod)

#### VMExecutable

In [None]:
import numpy as np

# Element type shared by every buffer in `mm_relu`; read from the enclosing
# scope inside the PrimFunc below.
dtype = "float32"


# TIR-only module: C = max(A @ B, 0) with symbolic shapes M, K and N.
@I.ir_module
class mm_relu:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        # Symbolic int64 dimensions, bound through the match_buffer calls below.
        M, K, N = T.int64(), T.int64(), T.int64()

        A_Buf = T.match_buffer(A, [M, K], dtype)
        B_Buf = T.match_buffer(B, [K, N], dtype)
        C_Buf = T.match_buffer(C, [M, N], dtype)

        # Intermediate buffer holding the matmul result before the max with 0.
        Y_Buf = T.alloc_buffer(shape=[M, N], dtype=dtype)

        # Stage 1: Y[vi, vj] = sum over vk of A[vi, vk] * B[vk, vj].
        for i, j, k in T.grid(M, N, K):
            with T.block("Y_Buf"):
                # "SSR": i and j are spatial axes, k is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y_Buf[vi, vj] = T.cast(0.0, dtype)
                Y_Buf[vi, vj] = Y_Buf[vi, vj] + A_Buf[vi, vk] * B_Buf[vk, vj]

        # Stage 2: element-wise max(Y, 0) written to the output buffer.
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C_Buf[vi, vj] = T.max(Y_Buf[vi, vj], T.cast(0.0, dtype))


mod = mm_relu
showmod(mod)


# Build the Relax virtual machine module
ex = tvm.relax.build(mod, target="llvm")
# NOTE(review): the result of `tvm.relax.build` is wrapped in VMExecutable
# again here; it may already be a VMExecutable, making this redundant -- confirm.
vmexecutable = tvm.relax.VMExecutable(ex)

# The recorded output shows an empty constant pool and no globals; presumably
# because `mm_relu` holds only a TIR PrimFunc and no Relax function -- confirm.
print(vmexecutable.stats())

# Alternatively, create the virtual machine directly
vm = tvm.relax.VirtualMachine(ex, tvm.cpu())

Relax VM executable statistics:
  Constant pool (# 0): []
  Globals (#0): []

